hdu_5293 Tree chain problem

题目:

  Coco has a tree, whose vertices are conveniently labeled by 1,2,…,n.
  There are m chains on the tree, and each chain has a certain weight. Coco would like to pick out some chains such that no two of them share a common vertex.
  Find the maximum total weight Coco can pick.

思路:

  一开始竟然没想到树形dp,QAQ。
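  设 $dp_u$ 为以 $u$ 为根的子树中,可取到的答案的最大值。如果要取链 $(a,b,w)$(其中 $\mathrm{lca}(a,b)=u$),那么路径上的所有点都要被删去,剩下能取到的值就是 $\sum_{x\in path(a,b)} sum_x-\sum_{x\in path(a,b),\,x\neq u} dp_x$(因为路径上的点的儿子可能也在路径上,要减去)。设 $sum_u=\sum_{v\in son(u)} dp_v$,则 $dp_u=\max\left(sum_u,\ \max_{(a,b,w)}\left\{w+\sum_{x\in path(a,b)} sum_x-\sum_{x\in path(a,b),\,x\neq u} dp_x\right\}\right)$。$sum$ 和 $dp$ 用两个树状数组维护,用树链剖分求这个式子,复杂度 $O((n+m)\log^2 n)$。如果用dfs序或者是并查集可以写到 $O((n+m)\log n)$。但是不会快多少。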

代码:

#include <bits/stdc++.h>
// Redirects stdin to "<name>.in" and stdout to "<name>.out" (the argument is stringized).
#define File(_) freopen(#_".in","r",stdin),freopen(#_".out","w",stdout)
// Ascending int loop over [a, b], both ends inclusive; b is evaluated once into i##END.
#define FOR(i,a,b) for(int i=(a),i##END=(b);i<=i##END;i++)
// Descending int loop from a down to b, both ends inclusive; b is evaluated once into i##END.
#define DOR(i,a,b) for(int i=(a),i##END=(b);i>=i##END;i--)
// memset over sizeof a bytes: a must be a real array, not a pointer, for the whole object to be filled.
#define clr(a,b) memset(a,b,sizeof a)
// Capacity of the per-vertex arrays and Fenwick trees declared in this file.
#define MN 100005
// Shorthand for long long.
#define ll long long 
using namespace std;
// Raises a to b when a < b. Returns true only if a was overwritten.
template<class T> bool tomax(T &a,T b){
    if(!(a<b))return false;
    a=b;
    return true;
}
// Lowers a to b when a > b. Returns true only if a was overwritten.
template<class T> bool tomin(T &a,T b){
    if(!(a>b))return false;
    a=b;
    return true;
}
// Array-backed singly linked lists: N list heads sharing a pool of M entries,
// each entry carrying a payload of type T. Entry ids start at 1 and id 0 ends
// a list, so at most M-1 entries can be stored.
template<int N,int M,class T> struct Link{
    int HEAD[N],NXT[M],tot;T W[M];
    // Empties every list. NXT and W are not wiped: add() rewrites an entry
    // before it becomes reachable again.
    void clear(){
        std::memset(HEAD,0,sizeof HEAD);
        tot=0;
    }
    // Pushes payload w onto the front of list x.
    void add(int x,T w){
        const int id=++tot;
        NXT[id]=HEAD[x];
        HEAD[x]=id;
        W[id]=w;
    }
    // Payload of entry x, where x is an entry id as produced by EOR, not a list index.
    T& operator [] (int x){return W[x];}
    // Iterates k over the entry ids of list o in G, most recently added first.
    #define EOR(k,G,o) for(int k=G.HEAD[o];k;k=G.NXT[k])
};
// Adjacency lists of the tree: one list per vertex, with room for 2*MN directed
// edge entries (main adds every undirected edge in both directions).
Link<MN,MN<<1,int> G;
// Fenwick (binary indexed) tree over positions 1..N-1 supporting point add and
// range sum. All arithmetic is done in int, so callers must keep partial sums
// within int range. Position 0 is not a valid update target (low(0) is 0).
template<int N> struct BIT{
    int s[N],a;
    // Lowest set bit of x.
    #define low(x) (-(x)&(x))
    // Adds d at position x (x >= 1).
    void upd(int x,int d){
        while(x<N){
            s[x]+=d;
            x+=low(x);
        }
    }
    // Sum of positions 1..x. The result is also left in the member a.
    int qry(int x){
        a=0;
        while(x){
            a+=s[x];
            x-=low(x);
        }
        return a;
    }
    // Sum of positions l..r inclusive.
    int qry(int l,int r){
        const int upto=qry(r);
        const int before=qry(l-1);
        return upto-before;
    }
    // Resets every position to zero.
    void clear(){std::memset(s,0,sizeof s);}
};
// Fenwick trees indexed by dfn. dfs_dp stores sum[o] into b1 and dp[o] into b2
// as each vertex is finished, so path queries only see already-processed vertices.
BIT<MN> b1,b2;
// Per-vertex heavy-light decomposition data filled by dfs1/dfs2:
// hs = heavy child (0 if none), siz = subtree size, tp = top vertex of the chain,
// dfn = visit order index, dep = depth (root has depth 1), fa = parent (0 for the root).
int hs[MN],siz[MN],tp[MN],dfn[MN],dep[MN],fa[MN];
// First decomposition pass over the subtree of o, whose parent is f (0 for the root).
// Fills fa, dep and siz, and picks the child with the largest subtree as hs[o];
// hs[o] stays 0 for a leaf.
void dfs1(int o,int f){
    fa[o]=f;
    dep[o]=dep[f]+1;
    siz[o]=1;
    hs[o]=0;
    for(int e=G.HEAD[o];e;e=G.NXT[e]){
        const int child=G[e];
        if(child!=f){
            dfs1(child,o);
            siz[o]+=siz[child];
            // Relies on siz[0] being 0 so that the first child always replaces the
            // placeholder; on ties the child seen earlier is kept.
            if(siz[hs[o]]<siz[child])hs[o]=child;
        }
    }
}
// Visit counter behind dfn; dfs2 resets it whenever it is entered with f == 0.
int No;
// Second decomposition pass: numbers the vertices so that each heavy chain
// occupies consecutive dfn values, and records the chain top of every vertex.
// o is the current vertex, f its parent, TP the top of the chain o belongs to.
void dfs2(int o,int f,int TP){
    if(f==0)No=0;
    dfn[o]=++No;
    tp[o]=TP;
    const int heavy=hs[o];
    // The heavy child is numbered immediately after o to keep the chain contiguous.
    if(heavy)dfs2(heavy,o,TP);
    for(int e=G.HEAD[o];e;e=G.NXT[e]){
        const int child=G[e];
        if(child==f||child==heavy)continue;
        // Every light child becomes the top of a new chain.
        dfs2(child,o,child);
    }
}
// Sum of the values stored in bit over every vertex on the tree path u..v
// (both endpoints included), gathered one heavy-chain segment at a time.
int calc(int u,int v,BIT<MN> &bit){
    int total=0;
    for(;tp[u]!=tp[v];u=fa[tp[u]]){
        // Always lift the endpoint whose chain top hangs deeper in the tree.
        if(dep[fa[tp[u]]]<dep[fa[tp[v]]])swap(u,v);
        total+=bit.qry(dfn[tp[u]],dfn[u]);
    }
    // Both endpoints now share a chain; the shallower one has the smaller dfn.
    const bool uIsLower=dep[u]>=dep[v];
    const int upper=uIsLower?v:u;
    const int lower=uIsLower?u:v;
    total+=bit.qry(dfn[upper],dfn[lower]);
    return total;
}
// Lowest common ancestor of u and v, found by jumping over heavy chains.
int lca(int u,int v){
    while(tp[u]!=tp[v]){
        // Lift whichever endpoint sits on the chain whose top is deeper.
        if(dep[fa[tp[u]]]>=dep[fa[tp[v]]])u=fa[tp[u]];
        else v=fa[tp[v]];
    }
    // On a shared chain the shallower vertex is the ancestor.
    return dep[u]>dep[v]?v:u;
}
// dp[o]: value computed by dfs_dp for the subtree of o (the answer is dp[1]).
// sum[o]: sum of dp over the children of o.
int dp[MN],sum[MN];
// A weighted chain: the path between vertices u and v with weight w.
struct Chain{int u,v,w;};
// ch[x] holds the chains whose endpoints have x as their lowest common ancestor (filled in main).
vector<Chain> ch[MN];
void dfs_dp(int o,int f){
    sum[o]=0;
    EOR(k,G,o){
        int v=G[k];
        if(v==f)continue;
        dfs_dp(v,o);
        sum[o]+=dp[v];
    }
    dp[o]=sum[o];
    b1.upd(dfn[o],sum[o]);
    FOR(i,0,(int)ch[o].size()-1){
        Chain c=ch[o][i];
        tomax(dp[o],c.w+calc(c.u,c.v,b1)-calc(c.u,c.v,b2));
    }
    b2.upd(dfn[o],dp[o]);
}
int main(){
    int n,m,T;
    scanf("%d",&T);
    while(T--){
        scanf("%d%d",&n,&m);
        FOR(i,1,n)ch[i].clear();
        G.clear();b1.clear();b2.clear();
        FOR(i,1,n-1){
            int a,b;
            scanf("%d%d",&a,&b);
            G.add(a,b);G.add(b,a);
        }
        dfs1(1,0);dfs2(1,0,1);
        FOR(i,1,m){
            Chain c;
            scanf("%d%d%d",&c.u,&c.v,&c.w);
            ch[lca(c.u,c.v)].push_back(c);
        }
        dfs_dp(1,0);
        printf("%d\n",dp[1]);
    }
    return 0;
}