LUOGU 3384
Description
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1
格式 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2
格式 2 x y
表示求树从x到y结点最短路径上所有节点的值之和
操作3
格式 3 x z
表示将以x为根节点的子树内所有节点值都加上z
操作4
格式 4 x
表示求以x为根节点的子树内所有节点值之和
Input Format
第一行包含4个正整数,分别表示树的结点个数、操作个数、根节点序号和模数(所有的输出结果均对此取模)。
第二行包含个非负整数,依次表示各个节点上初始的权值。
接下来每行包含两个整数,表示点和点之间连有一条边。
接下来行每行包含若干个正整数,每行表示一个操作,格式如题。
Output Format
输出包含若干行,依次表示每个操作2或操作4所得的结果。
Sample Input
5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
Sample Output
2
21
Constraints
对于%的数据:
对于%的数据:
对于%的数据:
CCYOS
学习树链剖分的板子题。
- 概念
重儿子 所有对于一个结点的所有子树,最大的子树的根节点是重儿子。
轻儿子 不是重儿子的其他子节点。
重边 链接该结点和它的的重儿子的边。
轻边 不是重边的其他树边。
重链 连续的重边。特别的,对于每个轻儿子有一条到自己本身的重链。重链的两端必然是轻边。
- 需要求出的数组:
dfs1 | dfs2 |
---|---|
fa[i] - i的父亲 | top[i] - i所在重链深度最小的节点 |
dep[i] - i的深度 | seg[i] - i在新序列中的编号 |
size[i] - 以i为根的子树大小 | rev[i] - 新序列中编号为i的树上节点 |
son[i] - i的重儿子 |
- 这个序列用线段树维护。
由于是dfs,且建立新序列时重儿子有先,所以一条重链上的编号是连续的,一颗子树上的编号也是连续的。
注意
a)线段树不要写错。
b)更新重儿子时初始比较值为0。
c)好事多模,但是别模太多。
code
#include<bits/stdc++.h>
using namespace std;
#define mxn 100005
int N,M,R,P,tot;
int size[mxn],fa[mxn],dep[mxn],top[mxn],son[mxn],num[mxn],rev[mxn],seg[mxn],head[mxn],sum[mxn << 2],tag[mxn << 2];
struct edge{
int to,nxt;
}edg[mxn << 1];
inline int read(){
char c = getchar();
int fl = 1,ret = 0;
for(;!isdigit(c) && c != '-';c = getchar())if(c == '-')fl = 0;
for(;isdigit(c);c = getchar())ret = (ret << 3) + (ret << 1) + c - 48;
return fl ? ret : -ret;
}
inline void add(int x,int y){
edg[++tot].to = y;
edg[tot].nxt = head[x];
head[x] = tot;
}
void dfs1(int u,int f){
fa[u] = f;size[u] = 1;
dep[u] = dep[f] + 1;
int mson = 0;//!!!
for(int e = head[u];e;e = edg[e].nxt){
int to = edg[e].to;
if(to == f)continue;
dfs1(to,u);
size[u] += size[to];
if(size[to] > mson)son[u] = to,mson = size[to];
}
}
void dfs2(int u,int tpf){
seg[u] = ++seg[0];
top[u] = tpf;
rev[seg[0]] = u;
if(!son[u])return;
dfs2(son[u],tpf);
for(int e = head[u];e;e = edg[e].nxt){
int to = edg[e].to;
if(to == fa[u]||to == son[u])continue;
dfs2(to,to);
}
}
#define ls (p << 1)
#define rs (p << 1 | 1)
inline void pushdown(int p,int l,int r){
int mid = (l + r) >> 1;
(tag[ls] += tag[p])%= P;(tag[rs] += tag[p]) %= P;
(sum[ls] += tag[p] * (mid - l + 1)) %= P;
(sum[rs] += tag[p] * (r - mid)) %= P;
tag[p] = 0;
}
inline void build(int l,int r,int p){
if(l == r){
sum[p] = num[rev[l]];
//cout<<p<<" "<<sum[p]<<endl;
return;
}
int mid = (l + r)>>1;;
build(l,mid,ls);build(mid + 1,r,rs);
sum[p] = (sum[ls] + sum[rs])%P;
}
inline void updS(int l,int r,int ql,int qr,int k,int p){
if(ql <= l && qr >= r){
tag[p] += k;
(sum[p] += k * (r - l + 1)) %= P;
return;
}
pushdown(p,l,r);
int mid = (l + r)>>1;
if(ql <= mid)updS(l,mid,ql,qr,k,ls);
if(qr > mid)updS(mid + 1,r,ql,qr,k,rs);
sum[p] = (sum[ls] + sum[rs]) % P;
}
inline int ask(int l,int r,int ql,int qr,int p){
if(ql <= l && qr >= r)
return sum[p];
pushdown(p,l,r);
int mid = (l + r) >> 1;
int ret = 0;
if(ql <= mid) (ret += ask(l,mid,ql,qr,ls)) %= P;
if(qr > mid) (ret += ask(mid + 1,r,ql,qr,rs)) %= P;
return ret;
}
inline void c1(int x,int y,int z){
z %= P;
int tx = top[x],ty = top[y];
while(tx != ty){
if(dep[tx] < dep[ty])swap(x,y),swap(tx,ty);
updS(1,N,seg[tx],seg[x],z,1);
x = fa[tx];tx = top[x];
}
if(dep[x] > dep[y])swap(x,y);
updS(1,N,seg[x],seg[y],z,1);
}
inline void c2(int x,int y){
int tx = top[x],ty = top[y];
int ans = 0;
while(tx != ty){
if(dep[tx] < dep[ty])swap(x,y),swap(tx,ty);
(ans += ask(1,N,seg[tx],seg[x],1)) %= P;
x = fa[tx];tx = top[x];
}
if(dep[x] > dep[y])swap(x,y);
(ans += ask(1,N,seg[x],seg[y],1)) %= P;
printf("%d\n",ans);
}
inline void c3(int x,int y){
updS(1,N,seg[x],seg[x] + size[x] - 1,y,1);
}
inline void c4(int x){
int ans = ask(1,N,seg[x],seg[x] + size[x] - 1,1) % P;
printf("%d\n",ans);
}
int main(){
N = read();M = read();R = read();P = read();
for(int i = 1;i <= N;++i)num[i] = read(),num[i] %= P;
for(int i = 1;i < N;++i){
int x = read();
int y = read();
add(x,y);add(y,x);
}
dfs1(R,0);
dfs2(R,R);
// for(int i = 1;i <= 5;++i)cout<<rev[i]<<endl;
build(1,N,1);
for(int i = 1;i <= M;++i){
int op = read();
int x,y,z;
if(op == 1){
x = read(),y = read(),z = read();
c1(x,y,z);
// for(int j = 1;j <= 9;++j)cout<<sum[j]<<" ";
// cout<<endl;
}
if(op == 2){
x = read(),y = read();c2(x,y);
}
if(op == 3){
x = read(),y = read();c3(x,y);
//for(int j = 1;j <= 9;++j)cout<<sum[j]<<" ";
// cout<<endl;
}
if(op == 4){
x = read();c4(x);
}
}
return 0;
}