解决:对于一株树(无向无环连通图),为每个结点分配对应的权重。要求能高效计算任意两个结点之间的路径的各类信息,其中包括路径长度(路径上所有结点的权重加总),路径中最大权重,最小权重等等。
思路:划分重链和轻边,通过两遍DFS把重链上节点的新编号都连接到一起。
以下代码实现:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e5+10;
struct Edge{
int to,next;
}edge[MAXN*2];
int head[MAXN],tot;
int top[MAXN]; //v所在的重链的顶端节点
int fa[MAXN]; //父亲节点
int deep[MAXN]; //深度
int num[MAXN]; //v为根的子树节点数
int p[MAXN]; //v对应的位置
int fp[MAXN]; //和p数组相反
int son[MAXN]; //重儿子
int pos;
int mood;
int a[MAXN];
void init(){ //初始化
tot = 0;
memset(head,-1,sizeof(head));
pos = 1;
memset(son,-1,sizeof(son));
}
void addedge(int u,int v){ //加边
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}
void dfs1(int u,int pre, int d){ //当前节点、父节点、层深度
deep[u] = d;
fa[u] = pre;
num[u] = 1;
for(int i = head[u]; i != -1; i = edge[i].next){
int v = edge[i].to;
if(v != pre){
dfs1(v,u,d+1);
num[u] += num[v];
if(son[u] == -1 || num[v] > num[son[u]])
son[u]=v;
}
}
}
void getpos(int u ,int sp){
top[u] = sp;
p[u] = pos++;
fp[p[u]] = u;
if(son[u] == -1) return ;
getpos(son[u],sp);
for(int i = head[u] ; i != -1; i = edge[i].next){
int v = edge[i].to;
if(v != son[u] && v != fa[u])
getpos(v,v);
}
}
//线段树
struct Node{
int l,r;
int val,laz;
}segTree[MAXN*20];
void build(int i,int l,int r){
segTree[i].l = l;
segTree[i].r = r;
segTree[i].laz = 0;
if(l == r){
segTree[i].val = a[fp[l]]%mood;;
return ;
}
int mid = (l+r)/2;
build(i<<1,l,mid);
build((i<<1)|1,mid+1,r);
segTree[i].val = (segTree[i<<1].val+segTree[i<<1|1].val)%mood;
}
void push_up(int i){
segTree[i].val = max(segTree[i<<1].val,segTree[(i<<1)|1].val)%mood;
}
void push_down(int i){
if(segTree[i].laz){
int k = segTree[i].laz;
segTree[i].laz = 0;
int l = segTree[i<<1].l,r = segTree[i<<1].r;
segTree[i<<1].laz += k;
segTree[i<<1].val += (r-l+1)*k%mood;
l = segTree[i<<1|1].l;
r = segTree[i<<1|1].r;
segTree[i<<1|1].laz += k;
segTree[i<<1|1].val += (r-l+1)*k%mood;
}
}
//更新线段树的第K个值为val
void update(int i,int k,int val){
if(segTree[i].l == k && segTree[i].r == k){
segTree[i].val = val;
return ;
}
int mid = (segTree[i].l + segTree[i].r)/2;
if(k <= mid) update(i<<1,k,val);
else update((i<<1)|1,k,val);
push_up(i);
}
//从x到y结点最短路径上所有节点的值都加上z,传参 (1,p[x],p[x]+num[x]-1,z)
void update(int i,int ql,int qr,int val){
int l = segTree[i].l,r = segTree[i].r;
int mid = (l+r)/2;
if(ql <= l && r <= qr){
segTree[i].laz += val;
segTree[i].val += (r-l+1)*val%mood;
return ;
}
push_down(i);
if(ql <= mid) update(i<<1,ql,qr,val);
if(qr > mid) update(i<<1|1,ql,qr,val);
segTree[i].val = (segTree[i<<1].val + segTree[i<<1|1].val)%mood;
}
//查询线段树[l,r]的最大值
int query(int i,int l,int r){
if(segTree[i].l == l && segTree[i].r == r){
return segTree[i].val;
}
int mid = (segTree[i].l + segTree[i].r)/2;
if(r <= mid) query(i<<1,l,r);
else if(l>mid)query((i<<1)|1,l,r);
else return max(query(i<<1,l,mid),query((i<<1)|1,mid+1,r));
}
//查询u-->v边的最大值
int Find(int u,int v){
int f1 = top[u],f2 = top[v];
int tmp = 0;
while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(f1,f2);
swap(v,u);
}
tmp = max(tmp,query(1,p[f1],p[u]));
u = fa[f1];
f1 = top[u];
}
if(u == v) return tmp;
if(deep[u] > deep[v]) swap(u,v);
return max(tmp,query(1,p[son[u]],p[v]));
}
//求以x为根节点的子树内所有节点值之和 传参 (1,p[x],p[x]+num[x]-1)
int querysegTree(int i,int ql,int qr){
int l = segTree[i].l,r = segTree[i].r;
int mid = (l+r)/2;
int ans=0;
if(ql <= l && r <= qr) return segTree[i].val;
push_down(i);
if(ql <= mid) ans += querysegTree(i<<1,ql,qr)%mood;
if(qr > mid) ans+=querysegTree(i<<1|1,ql,qr)%mood;
return ans%mood;
}
void solve1(int x,int y,int z){
while(top[x] != top[y]){
if(deep[top[x]] < deep[top[y]]) swap(x,y);
update(1,p[top[x]],p[x],z);
x = fa[top[x]];
}
if(deep[x] > deep[y]) swap(x,y);
update(1,p[x],p[y],z);
}
int solve2(int x,int y){
int ans = 0;
while(top[x] != top[y]){
if(deep[top[x]] < deep[top[y]]) swap(x,y);
ans += querysegTree(1,p[top[x]],p[x])%mood;
x = fa[top[x]];
}
if(deep[x] > deep[y]) swap(x,y);
ans += querysegTree(1,p[x],p[y])%mood;
return ans%mood;
}
void solve3(int x,int z){
update(1,p[x],num[x]+p[x]-1,z);
}
int solve4(int x){
return querysegTree(1,p[x],num[x]+p[x]-1)%mood;
}
int main(){
int n,m,r;
while(~scanf("%d %d %d %d",&n,&m,&r,&mood)){
init();
for(int i = 1;i <= n;i++){
scanf("%d",&a[i]);
}
for(int i = 1; i < n;i++){
int u,v;
scanf("%d %d",&u,&v);
addedge(u,v);
addedge(v,u);
}
dfs1(r,0,1);
getpos(r,r);
build(1,1,n);
for(int i = 0;i < m;i++){
int q,x,y,z;
scanf("%d",&q);
if(q == 1){
scanf("%d %d %d",&x,&y,&z);
solve1(x,y,z);
}else if(q == 2){
scanf("%d%d",&x,&y);
printf("%d\n",solve2(x,y));
}else if(q == 3){
scanf("%d%d",&x,&z);
solve3(x,z);
}else{
scanf("%d",&x);
printf("%d\n",solve4(x));
}
}
}
return 0;
}
//树状数组
/*
int lowbit(int x){
return x&(-x);
}
int c[MAXN];
int n;
int sum(int i){
int s = 0;
while(i > 0){
s += c[i];
i -=lowbit(i);
}
return s;
}
void add(int i,int val){
while(i <= n){
c[i] += val;
i += lowbit(i);
}
}
//u-->v 的路径上点的纸改变 val
void Change(int u,int v,int val){
int f1 = top[u], f2 = top[v];
int tmp = 0;
while(f1 != f2){
if(deep[f1] < deep[f2]){
swap(f1,f2);
swap(u,v);
}
add(p[f1],val);
add(p[u]+1,-val);
u = fa[f1];
f1 = top[u];
}
if(deep[u] > deep[v])
swap(u,v);
add(p[u],val);
add(p[v]+1,-val);
}
int a[MAXN];
*/