国内大佬们写的很难理解,找了个外国友人的文章,一下就看懂了。本文参考:
geeksforgeeks基础线段树
geeksforgeeks懒标记区间更新
要掌握线段树,得一步一步来。一上来就lazytag,很难理解。
一、普通单点修改
如果修改的单点属于当前树上节点覆盖的范围,直接改,然后改左右子树。没有什么pushup和pushdown。
//ss、se分别是当前树上节点覆盖范围开始和结束下标
//si是树上元素在树的数组里的下标,i是原数组下标,diff是加多少
//调用的时候从根开始update(1,n,1,5,20)
void update(int ss, int se,, int si, int i, int diff)
{
if (i < ss || i > se)
return;
st[si] = st[si] + diff;
if (se != ss)
{
int mid = getMid(ss, se);
update(ss, mid, 2*si, i, diff);
update(mid+1, se,,2*si+1, i, diff);
}
}
二、普通区间修改
区间修改,先看树上节点覆盖的范围和修改的范围有没有交集,没有就什么都不干;有的话分两种情况,一是到了叶子,直接更新;二是没到叶子节点,又分两种情况,1是节点覆盖范围被修改范围完全覆盖;2是不完全覆盖,不管哪种情况,做法都一样,直接更新左右子树,更新完以后,重新计算左右子树的值,更新当前节点值。也没有什么pushup和pushdown。
//us和ue分别是更新区间的下标开始、结束
void update(int ss, int se, int si, int us, int ue, int diff){
if (ss > ue || se < us) return;
if(ss == se){
st[si].v = st[si].v + diff;
return;
}
int mid = getmid(ss, se);
update(ss, mid, si * 2, us, ue, diff);
update(mid+1, se, si * 2 + 1, us, ue, diff);
st[si].v = st[si*2].v + st[si*2+1].v;
}
仔细体会这种方式,类似深度优先遍历,从根直接到叶子节点,叶子节点更新完成后,一层一层往上更新中间节点,最后更新根。
三、懒标记区间修改
暴力区间修改太慢了,最坏情况下,如果更新整个数组,复杂度O(nlogn),比直接在原数组上更新还慢,所以必须改进。
改进办法是加入懒标记,首先必须明确最重要的一点,当一个树上节点覆盖范围完全被更新区间包含时,这个节点和所有这个节点的子孙都需要更新;反之如果一个树上节点覆盖范围和更区间部分重合,则肯定有一部分子孙需要更新,另一部分绝不需要更新。我们的做法是,该更新还是更新,直接更新就行((se-ss+1)x diff),而不像上面暴力更新那样,深度优先到叶子上,从叶子一层一层往上更新。直接更新完以后,给子孙设置懒标记,被设置懒标记的节点,先不要动,等以后更新或者查询的时候,再处理。
一个节点的懒标记,延迟的是这个节点和它的所有子孙的更新。当一个节点遇到更新和查询操作时,有懒标记的话就先消化懒标记,然后把懒标记下传(也就是他们说的pushdown)给子孙,最后正常更新。
更新完一个节点后,也需要下传懒标记,停止更新进程,把子孙的更新推迟。
两种情况需要下传懒标记,一是自己消化懒标记时,二是自己更新时。下传懒标记的时候注意判断自己是不是叶子,不是才下传,是的话下传就数组越界了。
总之,懒标记是爸爸给他的,不是自己给自己的。懒标记的消化,在更新和查询操作中。懒标记消化分3步:更新自己、传给儿子、还原初始状态(还原或清零)。
举个例子,首先更新1-3,有个节点覆盖1-3,先把它更新,懒标记下传给1-2的爸爸,和3,结束。这时要查询2-4,需要查询2和3,这两个节点上都有懒标记,先消化,再返回。
看代码:
//洛谷p3373线段树模板2
#include <cstdio>
#define MAXN 100000
typedef long long ll;
using namespace std;
//线段树节点,v表示值,lza加法懒标记,lzm乘法懒标记
struct node {
ll v, lza, lzm;
} st[MAXN*4+1];
int a[MAXN+1];
int n, m, p;
inline int getmid(int s, int e){
return s + (e - s) / 2;
}
inline int left(int si){
return si * 2;
}
inline int right(int si){
return si * 2 + 1;
}
ll build(int ss, int se, int si){
st[si].lzm = 1;
if (ss == se) {
return st[si].v = a[ss] % p;
}
int mid = getmid(ss, se);
return st[si].v = (build(ss, mid, si * 2) + build(mid+1, se, si * 2 + 1)) % p;
}
void update(int ss, int se, int si, int us, int ue, int op, int opt){
if (st[si].lzm != 1){
st[si].v = st[si].v * st[si].lzm % p;//消化
if(ss != se){//下传
st[left(si)].lzm = st[left(si)].lzm * st[si].lzm % p;
st[left(si)].lza = st[left(si)].lza * st[si].lzm % p;
st[right(si)].lzm = st[right(si)].lzm * st[si].lzm % p;
st[right(si)].lza = st[right(si)].lza * st[si].lzm % p;
}
st[si].lzm = 1;//还原
}
if (st[si].lza != 0){
st[si].v = (st[si].v + (se - ss + 1) * st[si].lza) % p;
if (ss != se){
st[left(si)].lza = (st[left(si)].lza + st[si].lza) % p;
st[right(si)].lza = (st[right(si)].lza + st[si].lza) % p;
}
st[si].lza = 0;
}
if (ss > ue || se < us) return;
if (ss >= us && se <= ue){//完全在更新范围内
//先更新自己
if (op == 1){
st[si].v = st[si].v * opt % p;
} else if (op == 2){
st[si].v = (st[si].v + (se - ss + 1) * opt) % p;
}
if (ss != se){//给儿孙设置懒标记
if(op == 1){
st[left(si)].lzm = st[left(si)].lzm * opt % p;
st[left(si)].lza = st[left(si)].lza * opt % p;
st[right(si)].lzm = st[right(si)].lzm * opt % p;
st[right(si)].lza = st[right(si)].lza * opt % p;
} else {
st[left(si)].lza += opt;
st[right(si)].lza += opt;
}
}
return;
}
int mid = getmid(ss, se);
update(ss, mid, left(si), us, ue, op, opt);
update(mid+1, se, right(si), us, ue, op, opt);
st[si].v = (st[left(si)].v + st[right(si)].v) % p;
}
ll query(int ss, int se, int si, int qs, int qe){
if (st[si].lzm != 1){
st[si].v = st[si].v * st[si].lzm % p;//消化
if(ss != se){//下传
st[left(si)].lzm = st[left(si)].lzm * st[si].lzm % p;
st[left(si)].lza = st[left(si)].lza * st[si].lzm % p;
st[right(si)].lzm = st[right(si)].lzm * st[si].lzm % p;
st[right(si)].lza = st[right(si)].lza * st[si].lzm % p;
}
st[si].lzm = 1;//还原
}
if (st[si].lza != 0){
st[si].v = (st[si].v + (se - ss + 1) * st[si].lza) % p;
if (ss != se){
st[left(si)].lza = (st[left(si)].lza + st[si].lza) % p;
st[right(si)].lza = (st[right(si)].lza + st[si].lza) % p;
}
st[si].lza = 0;
}
if(ss >= qs && se <= qe){
return st[si].v;
}
if (ss > qe || se < qs) return 0;
int mid = getmid(ss, se);
return (query(ss, mid, si * 2, qs, qe) + query(mid + 1, se, si * 2 + 1, qs, qe)) % p;
}
int main(){
// freopen("P3373_2.in", "r", stdin);
scanf("%d%d%d", &n, &m, &p);
for(int i = 1; i <= n; i++){
scanf("%d", a + i);
}
build(1, n, 1);
int op, x, y, k;
while(m--){
scanf("%d", &op);
if (op == 1 || op == 2){
scanf("%d%d%d", &x, &y, &k);
update(1, n, 1, x, y, op, k);
} else {
scanf("%d%d", &x, &y);
printf("%lld\n", query(1, n, 1, x, y));
}
}
return 0;
}