清华集训2017 Hello world!

题目:

  不远的一年前,小V还是一名清华集训的选手,坐在机房里为他已如风中残烛的OI生涯做最后的挣扎。而如今,他已成为了一名光荣的出题人。他感到非常激动,不禁感叹道:“Hello world!”。
  小V有 道题,他的题都非常毒瘤,所以关爱选手的 ufozgg 打算削弱这些题。为了逃避削弱,小V把他的毒瘤题都藏到了一棵 个节点的树里(节点编号从 ),这棵树上的所有节点与小V的所有题一一对应。小V的每一道题都有一个毒瘤值,节点 (表示标号为 的树上节点,下同)对应的题的毒瘤值为
  魔法师小V为了保护他的题目,对这棵树施了魔法,这样一来,任何人想要一探这棵树的究竟,都必须在上面做跳跃操作。每一次跳跃操作包含一个起点 、一个终点 和一个步频 ,这表示跳跃者会从 出发,在树上沿着简单路径多次跳跃到达 ,每次跳跃,如果从当前点到 的最短路长度不超过 ,那么跳跃者就会直接跳到 ,否则跳跃者就会沿着最短路跳过恰好 条边。
  既然小V把题藏在了树里,ufozgg 就不能直接削弱题目了。他就必须在树上跳跃,边跳跃边削弱题目。ufozgg 每次跳跃经过一个节点(包括起点 ,当 的时候也是如此),就会把该节点上的题目的毒瘤值开根并向下取整:即如果他经过了节点 ,他就会使 。这种操作我们称为削弱操作。
  ufozgg 还会不时地希望知道他对题目的削弱程度。因此,他在一些跳跃操作中会放弃对题目的削弱,转而统计该次跳跃经过节点的题目毒瘤值总和。这种操作我们称为统计操作。
  吃瓜群众绿绿对小V的毒瘤题和 ufozgg 的削弱计划常感兴趣。他现在想知道 ufozgg 每次做统计操作时得到的结果。你能帮帮他吗?

思路:

  因为 的递减速度很快( 次就变成 了),所以每个点只有前 次削弱操作是有意义的。
  假设所有的 都等于 ,每次修改时用并查集跳过 的点后暴力修改,查询时用树状数组和 序来维护每个点到根路径上的 。就能做到 了。但如果 不同,那就得对于每个 分开维护,这样最多需要 个并查集和树状数组,显然是不行的。
  注意到每次经过的点只有 个,如果用树剖每次往上跳 个点,每次复杂度只有
  所以可以考虑对操作分类,如果 ,就用第二种方法,如果 就用第一种。每次修改每个点时需要同时修改 个树状数组,所以复杂度应该是 。尝试了一下发现 时最快。
  在链上跳到 的时候有一些细节,一定要细心!!!

代码:

#include <bits/stdc++.h>
#define debug(format, ...) fprintf(stderr, format, __VA_ARGS__)
#define File(_) freopen(#_".in", "r", stdin), freopen(#_".out", "w", stdout)
typedef long long ll;
const int N = 50005, M = 35;

template<int N, int M, class T> struct Link {
#define erep(k, G, o) for(int k = G.HEAD[o]; k; k = G.NXT[k])
    int HEAD[N], NXT[M], tot; T W[M];
    void add(int x, T w) {NXT[++tot] = HEAD[x]; W[HEAD[x] = tot] = w;}
    T& operator[] (int x) {return W[x];}
};
Link<N, N * 2, int> G;

struct Union_Find {
    int a[N];
    void init(int n) {for(int i = 1; i <= n; i++) a[i] = i;}
    int find(int x) {return a[x] == x ? x : a[x] = find(a[x]);}
    int& operator[] (int x) {find(x); return a[x];}
} ufs[M];

int fa[N][M], tp[N], dfn[N], mp[N], dep[N], sz[N], hs[N], cur_dfn;
void dfs1(int o, int f) {
    fa[o][1] = f; dep[o] = dep[f] + 1; sz[o] = 1;
    erep(k, G, o) {
        int v = G[k];
        if(v == f) continue;
        dfs1(v, o);
        sz[o] += sz[v];
        if(sz[hs[o]] < sz[v]) hs[o] = v;
    }
}
void dfs2(int o, int f, int t) {
    tp[o] = t; dfn[o] = ++cur_dfn; mp[dfn[o]] = o;
    if(hs[o]) dfs2(hs[o], o, t);
    erep(k, G, o) {
        int v = G[k];
        if(v == f || v == hs[o]) continue;
        dfs2(v, o, v);
    }
}
int lca(int u, int v) {
    while(tp[u] != tp[v]) {
        if(dep[fa[tp[u]][1]] < dep[fa[tp[v]][1]])
            std::swap(u, v);
        u = fa[tp[u]][1];
    }
    return dep[u] < dep[v] ? u : v;
}
int jmp(int o, int k) {
    if(dep[o] <= k) return 0;
    while(dep[o] - dep[fa[tp[o]][1]] <= k) {
        k -= dep[o] - dep[fa[tp[o]][1]];
        o = fa[tp[o]][1];
    }
    return mp[dfn[o] - k];
}

struct BIT {
#define low(x) (-(x) & (x))
    ll c[N], a;
    void upd(int w, ll x) {
        for(; w < N; w += low(w)) c[w] += x;
    }
    ll qry(int w) {
        for(a = 0; w; w -= low(w)) a += c[w];
        return a;
    }
};

ll a[N];

struct Sum {
    Link<N, N, int> G;
    BIT sum;
    int dfn[N], en[N], cur;
    Sum() {cur = 0;}
    void dfs(int o) {
        dfn[o] = ++cur;
        erep(k, G, o) dfs(G[k]);
        en[o] = cur;
    }
    void upd(int o, ll x) {sum.upd(dfn[o], x); sum.upd(en[o] + 1, -x);}
    ll qry(int o) {return sum.qry(dfn[o]);}
    void init(int k, int n) {
        for(int i = 1; i <= n; i++) G.add(fa[i][k], i);
        dfs(0);
        for(int i = 1; i <= n; i++) upd(i, a[i]);
    }
} sum[M];

void sqrt_node(int o) {
    if(ufs[1][o] != o) return;
    ll t = std::sqrt(a[o]);
    for(int i = 1; i < M; i++) sum[i].upd(o, t - a[o]);
    if(t == 1) for(int i = 1; i < M; i++) ufs[i][o] = ufs[i][fa[o][i]];
    a[o] = t;
}
void upd1(int u, int v, int k) {
    int p = lca(u, v);
    int t = ((dep[u] - dep[p]) % k), x = (dep[v] - dep[p] + t) % k;
    while(dep[u = ufs[k][u]] >= dep[p]) {
        sqrt_node(u);
        u = fa[u][k];
    }
    if(x != 0) sqrt_node(v);
    v = fa[v][x];
    while(dep[v = ufs[k][v]] > dep[p]) {
        sqrt_node(v);
        v = fa[v][k];
    }
}
void upd2(int u, int v, int k) {
    int p = lca(u, v);
    int t = ((dep[u] - dep[p]) % k), x = (dep[v] - dep[p] + t) % k;
    while(dep[u] >= dep[p]) {
        sqrt_node(u);
        u = jmp(u, k);
    }
    if(x != 0) sqrt_node(v);
    v = jmp(v, x);
    while(dep[v] > dep[p]) {
        sqrt_node(v);
        v = jmp(v, k);
    }
}

ll qry1(int u, int v, int k) {
    int p = lca(u, v);
    bool flag = false;
    int t = dep[u] - dep[p], x = t % k;
    ll ans = sum[k].qry(u) - sum[k].qry(jmp(u, t - x + k));
    if(jmp(u, t - x) == p) flag = true;
    t = ((dep[u] - dep[p]) % k), x = (dep[v] - dep[p] + t) % k;
    if(x != 0) ans += a[v];
    v = fa[v][x];
    if(dep[v] < dep[p]) return ans;
    t = dep[v] - dep[p]; x = t % k;
    ans += sum[k].qry(v) - sum[k].qry(jmp(v, t - x + k));
    if(flag) ans -= a[p];
    return ans;
}
ll qry2(int u, int v, int k) {
    int p = lca(u, v);
    int t = ((dep[u] - dep[p]) % k), x = (dep[v] - dep[p] + t) % k;
    ll ans = 0;
    while(dep[u] >= dep[p]) {
        ans += a[u];
        u = jmp(u, k);
    }
    if(x != 0) ans += a[v];
    v = jmp(v, x);
    while(dep[v] > dep[p]) {
        ans += a[v];
        v = jmp(v, k);
    }
    return ans;
}

int main() {
    File(hello);
    int n, q;
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) scanf("%lld", a + i);
    for(int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        G.add(u, v); G.add(v, u);
    }
    dfs1(1, 0); dfs2(1, 0, 1);
    sum[1].init(1, n); ufs[1].init(n);
    for(int i = 1; i <= n; i++) fa[i][0] = i;
    for(int k = 2; k < M; k++) {
        for(int i = 1; i <= n; i++)
            fa[i][k] = fa[fa[i][k - 1]][1];
        sum[k].init(k, n); ufs[k].init(n);
    }
    scanf("%d", &q);
    for(int i = 1; i <= q; i++) {
        int op, u, v, k;
        // debug("%d\n", i);
        scanf("%d%d%d%d", &op, &u, &v, &k);
        if(op == 0) {
            if(u == v) {sqrt_node(u); continue;}
            if(k < M) upd1(u, v, k);
            else upd2(u, v, k);
        }
        else {
            if(u == v) {printf("%lld\n", a[u]); continue;}
            if(k < M) printf("%lld\n", qry1(u, v, k));
            else printf("%lld\n", qry2(u, v, k));
        }
    }
    return 0;
}