洛谷 P3384 【模板】树链剖分

题目链接

https://www.luogu.com.cn/problem/P3384

分析

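树链剖分(这里指重链剖分)将树划分成若干条链:每个节点取子树最大的儿子作为重儿子,DFS 时优先访问重儿子,于是每条重链在 DFS 序中是连续的一段,每棵子树也对应 DFS 序中一段连续区间。任意两点间的路径可以拆成 O(log n) 条重链上的区间,因此路径和子树上的修改、查询都能转化为序列上的区间操作,再用线段树(或平衡树)维护对应序列即可。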

AC代码

#include <cstdio>
#include <iostream>
#include <algorithm>

using namespace std;

// Read the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (so '-' signs are
// ignored). Returns 0 if end of input is reached before a digit is found.
inline int get_num() {
    // Keep the getchar() result in an int so EOF stays distinguishable from
    // every real character; a plain char cannot represent it reliably.
    int c = getchar();
    // Skip separators, but stop at EOF instead of looping forever on
    // truncated input.
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    int num = 0;
    while (c >= '0' && c <= '9') {
        num = num * 10 + (c - '0');
        c = getchar();
    }
    return num;
}

// Capacity for vertex-indexed arrays: 100000 vertices plus a little slack.
const int maxn = 100000 + 5;

// Adjacency lists stored as singly linked edge chains (1-indexed edges).
int head[maxn];  // head[u]: index of the most recently added edge out of u, 0 if none
int eid;         // number of edges stored so far

struct Edge {
    int v;     // endpoint this edge points to
    int next;  // next edge out of the same source vertex, 0 terminates the chain
};
// The tree is stored undirected, so each of the n-1 input edges is kept twice.
Edge edge[2 * maxn];

inline void insert(int u, int v) {
    edge[++eid].v = v;
    edge[eid].next = head[u];
    head[u] = eid;
}

// Problem input: vertex count, operation count, root vertex, modulus.
int n, m, r, mod;
int a[maxn];       // initial weight of each vertex (1-indexed by vertex id)
// size[u]: number of vertices in u's subtree (filled by dfs0).
// NOTE(review): with `using namespace std;` in effect, compiling as C++17 or
// later can make unqualified `size` ambiguous with `std::size`; consider
// renaming this array (e.g. `sz`) throughout the file.
int size[maxn];
int depth[maxn];   // depth of each vertex; the root gets depth[0] + 1
int f[maxn];       // parent of each vertex (0 for the root)

int chain[maxn];   // chain[u]: id of the heavy chain containing u
int first[maxn];   // first[c]: top (shallowest) vertex of chain c
int tot = 1;       // last chain id handed out; chain 1 is the root's chain

int order[maxn];      // order[u]: position of u in the heavy-first DFS order
int dfs_order[maxn];  // dfs_order[i]: vertex at DFS position i (inverse of order)
int cnt;              // DFS positions assigned so far

// First pass over the tree rooted at u (reached from parent fa).
// Records f[u], depth[u] and size[u] for every vertex in the subtree and
// returns size[u]. Recursion depth equals the height of the tree.
int dfs0(int u, int fa) {
    f[u] = fa;
    depth[u] = depth[fa] + 1;
    int total = 1;  // u itself
    for (int e = head[u]; e != 0; e = edge[e].next) {
        const int child = edge[e].v;
        if (child == fa) continue;  // don't walk back up to the parent
        total += dfs0(child, u);
    }
    size[u] = total;
    return total;
}

// Second pass: assign DFS positions and chain ids.
// u is placed at the next DFS position and joins chain ch. Its heavy child
// (the first child with the strictly largest subtree) is visited first and
// stays on the same chain, so every chain occupies a contiguous DFS range.
// Every other child starts a new chain with id ++tot and itself as top.
void dfs(int u, int fa, int ch) {
    order[u] = ++cnt;
    dfs_order[cnt] = u;
    chain[u] = ch;

    // size[0] is 0 (vertex 0 is unused), so heavy == 0 means "no child yet".
    int heavy = 0;
    for (int e = head[u]; e; e = edge[e].next) {
        const int v = edge[e].v;
        if (v == fa) continue;
        if (size[v] > size[heavy]) heavy = v;
    }
    if (heavy == 0) return;  // leaf

    dfs(heavy, u, ch);  // heavy child continues the current chain

    for (int e = head[u]; e; e = edge[e].next) {
        const int v = edge[e].v;
        if (v == fa || v == heavy) continue;
        ++tot;
        first[tot] = v;  // light child becomes the top of a new chain
        dfs(v, u, tot);
    }
}

// Segment-tree node over DFS positions, with a lazy "add to range" tag.
struct ST {
    int l;    // leftmost DFS position covered by this node
    int r;    // rightmost DFS position covered by this node
    int sum;  // sum of the covered values, reduced modulo mod
    int tag;  // pending addition not yet pushed to the children
};
// Node 1 is the root; the children of node p are 2p and 2p+1.
ST st[4 * maxn];

// Recompute node p's sum from its two children (modulo mod).
// Each child sum is below mod, and mod is read as an int, so it may be close to
// INT_MAX. Their total can then exceed INT_MAX, so it is added in long long.
inline void up(int p) {
    st[p].sum = static_cast<int>((1LL * st[2 * p].sum + st[2 * p + 1].sum) % mod);
}

// Apply "+d to every element" to segment-tree node p.
// Updates the node's stored sum and accumulates d into its lazy tag.
// d can be a raw input value that has not been reduced modulo mod yet, so it is
// reduced first. All arithmetic is done in long long: sum, tag and d may each
// be close to mod (an arbitrary int), and d * (segment length) easily exceeds
// the int range.
inline void mark(int p, int d) {
    const long long add = d % mod;
    const long long len = st[p].r - st[p].l + 1;
    st[p].sum = static_cast<int>((st[p].sum + add * len % mod) % mod);
    st[p].tag = static_cast<int>((st[p].tag + add) % mod);
}

// Push node p's pending lazy addition down to both children, then clear it.
inline void down(int p) {
    const int pending = st[p].tag;
    if (pending == 0) return;  // nothing to propagate
    mark(2 * p, pending);
    mark(2 * p + 1, pending);
    st[p].tag = 0;
}

// Build the segment-tree node p covering DFS positions [l, r].
// Leaf i holds the weight of the vertex at DFS position i (dfs_order[i]),
// reduced modulo mod.
void build(int p, int l, int r) {
    st[p].l = l;
    st[p].r = r;
    if (l != r) {
        const int mid = (l + r) / 2;
        build(2 * p, l, mid);
        build(2 * p + 1, mid + 1, r);
        up(p);
    } else {
        st[p].sum = a[dfs_order[l]] % mod;
    }
}

// Return the sum (modulo mod) of DFS positions [l, r] within node p's range.
// Pending tags are pushed down on the way. The two partial results are each
// below mod, so they are added in long long to avoid int overflow when mod is
// large.
int query0(int p, int l, int r) {
    if (l <= st[p].l && st[p].r <= r) return st[p].sum;
    down(p);
    const int mid = (st[p].l + st[p].r) / 2;
    long long ret = 0;
    if (l <= mid) ret += query0(2 * p, l, r);
    if (r > mid) ret += query0(2 * p + 1, l, r);
    return static_cast<int>(ret % mod);
}

// Add d to every DFS position in [l, r] within node p's range (lazy propagation).
void modify0(int p, int l, int r, int d) {
    const bool covered = l <= st[p].l && st[p].r <= r;
    if (covered) {
        mark(p, d);  // whole node inside the range: tag it and stop
        return;
    }
    down(p);
    const int mid = (st[p].l + st[p].r) / 2;
    const int left = 2 * p;
    const int right = 2 * p + 1;
    if (l <= mid) modify0(left, l, r, d);
    if (mid < r) modify0(right, l, r, d);
    up(p);
}

// Sum of vertex weights on the tree path p1 -- p2, modulo mod.
// While the endpoints lie on different chains, take the endpoint whose chain
// top is deeper, add its chain segment [top, endpoint], and jump above the top.
// Once both are on one chain, add the segment between them.
// The accumulator is long long so that ans + partial (both below mod) cannot
// overflow int.
inline int query(int p1, int p2) {
    long long ans = 0;
    while (chain[p1] != chain[p2]) {
        if (depth[first[chain[p1]]] < depth[first[chain[p2]]]) swap(p1, p2);
        const int top = first[chain[p1]];
        ans = (ans + query0(1, order[top], order[p1])) % mod;
        p1 = f[top];
    }
    if (depth[p1] > depth[p2]) swap(p1, p2);  // p1 is now the shallower one
    ans = (ans + query0(1, order[p1], order[p2])) % mod;
    return static_cast<int>(ans);
}

// Add d to every vertex on the tree path p1 -- p2.
// Repeatedly update the chain segment of the endpoint whose chain top is
// deeper and jump above that top. When both endpoints share a chain, update
// the contiguous DFS range between them.
inline void modify(int p1, int p2, int d) {
    while (chain[p1] != chain[p2]) {
        if (depth[first[chain[p2]]] > depth[first[chain[p1]]]) swap(p1, p2);
        const int top = first[chain[p1]];
        modify0(1, order[top], order[p1], d);
        p1 = f[top];  // continue from just above this chain
    }
    if (depth[p2] < depth[p1]) swap(p1, p2);
    modify0(1, order[p1], order[p2], d);
}

// Read the tree, run the two HLD passes, build the segment tree, then answer
// m operations:
//   1 x y z : add z to every vertex on path x--y
//   2 x y   : print the path sum x--y (mod mod)
//   3 x y   : add y to every vertex in x's subtree
//   other   : read x and print the subtree sum of x (mod mod)
// A subtree occupies DFS positions [order[x], order[x] + size[x] - 1].
int main() {
    n = get_num();
    m = get_num();
    r = get_num();
    mod = get_num();
    for (int v = 1; v <= n; ++v) a[v] = get_num();
    for (int k = 1; k < n; ++k) {
        const int x = get_num();
        const int y = get_num();
        insert(x, y);
        insert(y, x);
    }

    dfs0(r, 0);
    first[tot] = r;  // chain 1 starts at the root
    dfs(r, 0, tot);
    build(1, 1, n);

    for (int q = 0; q < m; ++q) {
        const int op = get_num();
        switch (op) {
        case 1: {
            const int x = get_num();
            const int y = get_num();
            const int z = get_num();
            modify(x, y, z);
            break;
        }
        case 2: {
            const int x = get_num();
            const int y = get_num();
            printf("%d\n", query(x, y));
            break;
        }
        case 3: {
            const int x = get_num();
            const int y = get_num();
            modify0(1, order[x], order[x] + size[x] - 1, y);
            break;
        }
        default: {
            const int x = get_num();
            printf("%d\n", query0(1, order[x], order[x] + size[x] - 1));
            break;
        }
        }
    }
    return 0;
}
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容