hdu_5405 Sometimes Naive

题目:

  Rhason Cheung had a naive problem, and asked Teacher Mai for help. But Teacher Mai thought this problem was too simple, sometimes naive. So she ask you for help.
  She has a tree with n vertices, numbered from 1 to n. The weight of i-th node is wi.
  You need to support two kinds of operations: modification and query.
  For a modification operation u,w, you need to change the weight of u-th node into w.
  For a query operation , you should output . If there is a vertex on the path from u to v and the path from i to j in the tree, , otherwise . The number can be large, so print the number modulo .

思路:

  如果用容斥的方法统计答案就等于整个树的权值和的平方减去拆去后每个联通块权值和的平方和。那么主要的问题就是考虑怎么快速统计拆去后每个联通块权值和的平方和
  如果考虑每次 查询,最简单的办法就是枚举 上每一个点,再枚举每个点的儿子 。如果 不在 上就加上 子树和的平方。最后在特殊考虑一下 即可(设 在上方)。
  设着优化这个过程,如果在树上轻重链剖分的话,就要能快速的求一条重链的权值和。在每个点上记录这个点所有的轻儿子的子树和的平方和 ,那么所有不是重链末端的点的权值就是 。对于重链末端的点,上一条重链的端点肯定是它的轻儿子,这个值不应被统计,要减去。本身的重儿子却没有被统计,应该加入。用树剖加树状数组维护,复杂度可以做到
  如果用 LCT 维护的话,复杂度可以做到 ,而且不用考虑两条重链之间的问题,但是常数会比较大。

代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10, MOD = 1e9 + 7;

inline int pw(int x) {
    return (ll) x * x % MOD;
}

template<int N, int M, class T> struct Link {
#define erep(k, G, o) for(int k = G.HEAD[o]; k; k = G.NXT[k])
    int HEAD[N], NXT[M], tot; T W[M];
    void clear() {memset(HEAD, 0, sizeof HEAD); tot = 0;}
    void add(int o, T w) {NXT[++tot] = HEAD[o]; W[HEAD[o] = tot] = w;}
    T& operator[] (int x) {return W[x];}
};
Link<N, N * 2, int> G;

template<int N> struct BIT {
#define low(x) (-x & x)
    int c[N], a;
    void clear() {memset(c, 0, sizeof c);}
    void add(int x, int v) {
        for(; x < N; x += low(x))
            (c[x] += v) %= MOD;
    }
    int qry(int x) {
        for(a = 0; x; x -= low(x))
            (a += c[x]) %= MOD;
        return a;
    }
    int query(int l, int r) {
        return (qry(r) - qry(l - 1) + MOD) % MOD;
    }
};
BIT<N> sum, tre;

int size[N], dep[N], fa[N], hs[N], w[N];
void dfs1(int o, int f) {
    size[o] = 1; fa[o] = f;
    hs[o] = 0; dep[o] = dep[f] + 1;
    erep(k, G, o) {
        int v = G[k];
        if(v == f) continue;
        dfs1(v, o);
        size[o] += size[v];
        if(size[v] > size[hs[o]])
            hs[o] = v;
    }
}
int tp[N], dfn[N], en[N], dfn_cur;
void dfs2(int o, int f, int t) {
    dfn[o] = ++dfn_cur; tp[o] = t;
    tre.add(dfn[o], w[o]);
    if(hs[o]) dfs2(hs[o], o, t);
    erep(k, G, o) {
        int v = G[k];
        if(v == f || v == hs[o]) continue;
        dfs2(v, o, v);
        sum.add(dfn[o], pw(tre.query(dfn[v], en[v])));
    }
    en[o] = dfn_cur;
}
void update(int o, int w) {
    int x = o, p = tre.query(dfn[o], dfn[o]);
    while(tp[o] != 1) {
        o = tp[o];
        int t = tre.query(dfn[o], en[o]);
        sum.add(dfn[fa[o]], (pw((w - p + t) % MOD) - pw(t) + MOD) % MOD);
        o = fa[o];
    }
    tre.add(dfn[x], (w - p + MOD) % MOD);
}
int query(int u, int v) {
    int lastu = 0, lastv = 0;
    int ans = 0, all = tre.query(dfn[1], en[1]);
    while(tp[u] != tp[v]) {
        if(dep[tp[u]] < dep[tp[v]]) {
            swap(u, v);
            swap(lastu, lastv);
        }
        if(hs[u]) (ans += pw(tre.query(dfn[hs[u]], en[hs[u]]))) %= MOD;
        (ans += sum.query(dfn[tp[u]], dfn[u])) %= MOD;
        if(lastu) (ans += -pw(tre.query(dfn[lastu], en[lastu])) + MOD) %= MOD;
        lastu = tp[u];
        u = fa[tp[u]];
    }
    if(dep[u] < dep[v]) {
        swap(u, v);
        swap(lastu, lastv);
    }
    if(hs[u]) (ans += pw(tre.query(dfn[hs[u]], en[hs[u]]))) %= MOD;
    (ans += sum.query(dfn[v], dfn[u])) %= MOD;
    if(lastu) (ans += -pw(tre.query(dfn[lastu], en[lastu])) + MOD) %= MOD;
    if(lastv) (ans += -pw(tre.query(dfn[lastv], en[lastv])) + MOD) %= MOD;
    (ans += pw((all - tre.query(dfn[v], en[v]) + MOD) % MOD)) %= MOD;
    ans = (pw(all) - ans + MOD) % MOD;
    return ans;
}

int main() {
    // freopen("data.in", "r", stdin);
    int n, m;
    while(~scanf("%d%d", &n, &m)) {
        for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
        G.clear(); tre.clear(); sum.clear();
        for(int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            G.add(u, v); G.add(v, u);
        }
        dfn_cur = 0;
        dfs1(1, 0); dfs2(1, 0, 1);
        for(int i = 1; i <= m; i++) {
            int f, a, b;
            scanf("%d%d%d", &f, &a, &b);
            if(f == 1) update(a, b);
            else printf("%d\n", query(a, b));
        }
    }
    return 0;
}