题解 P2664 【树上游戏】

· · 题解

From SLYZ Leaves

没有题解呢

于是来写一份。

最初看到这个题是群里一个dalao安利的。

很明显,这种处理树上点对、nlogn可以接受且单次询问的问题

肯定是点分治跑不掉了。

那么现在问题就变成了:你有一棵树,如何在O(n)的处理出,以根为lca的点对的答案?

一个很重要的性质:对于树中的一点i,如果该点的颜色在该点到根这条链上是第一次出现,那么对于这棵树的其他点j(j与i的lca为根),均能与i的子树(包括i)组成点对,i的颜色会对j的答案贡献size[i]。(我们在此暂且不考虑j到根的链上是否出现了i的颜色)

这个性质还是很好证明的,大家画画图就能看懂了。

那么我们就可以这样做了:

1.对树进行第一遍dfs,预处理size和上方性质中的贡献。(开一个color数组即可记录贡献),同时记录贡献总和sum

2.枚举跟的所有子树,先把子树扫一遍清除其在color数组中的所有贡献。接着,对于该子树中的每一个点j:

设X=sigma color[j 到根上(不包括根)的所有颜色] 由于这些颜色已经出现过,我们不能在该子树外计算其贡献)

设num为j到根上(不包括根)的颜色数

设Y为size[root]-size[该子树](即所有其他子树+根的点数)

则ans[j]+=sum-X+num*Y;

3.别忘了计算单独root的ans

  ans[root]+=sum-color[根的颜色]+size[root]

4.清空贡献数组以及其他东西

那么点分治就解决了~

附上代码:

    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #define o 200011
    #define ll long long
    using namespace std;
    const int inf=1e8;
    int head[o],nxt[o*2],point[o*2],V[o];
    ll color[o],ans[o],much,sum,num,size[o],cnt[o],total,record;
    int tot,n,ui,vi,root;
    bool vis[o*2];
    void addedge(int x,int y){
        tot++;nxt[tot]=head[x];head[x]=tot;point[tot]=y;
        tot++;nxt[tot]=head[y];head[y]=tot;point[tot]=x;
    }
    void findroot(int now,int dad){
        size[now]=1;
        ll maxson=0;
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(v!=dad&&!vis[tmp]){
                findroot(v,now);
                size[now]+=size[v];
                maxson=max(maxson,size[v]);
            }
        }
        maxson=max(maxson,total-size[now]);
        if(maxson<record) root=now,record=maxson;
    }        
    void dfs1(int now,int dad){
        size[now]=1;
        cnt[V[now]]++;
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(!vis[tmp]&&v!=dad){
                dfs1(v,now);
                size[now]+=size[v];
            }
        }
        if(cnt[V[now]]==1){
            sum+=size[now];
            color[V[now]]+=size[now];
        } 
        cnt[V[now]]--;
    }
    void change(int now,int dad,int value){
        cnt[V[now]]++; 
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(!vis[tmp]&&v!=dad) change(v,now,value);
        }
        if(cnt[V[now]]==1){
            sum+=(ll)size[now]*value;
            color[V[now]]+=(ll)size[now]*value;
        } 
        cnt[V[now]]--;
    }
    void dfs2(int now,int dad){
        cnt[V[now]]++;
        if(cnt[V[now]]==1){
            sum-=color[V[now]];
            num++;
        }
        ans[now]+=sum+num*much;
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(!vis[tmp]&&v!=dad) dfs2(v,now);
        }    
        if(cnt[V[now]]==1){
            sum+=color[V[now]];
            num--;
        }
        cnt[V[now]]--;
    }
    void clear(int now,int dad){
        cnt[V[now]]=0;
        color[V[now]]=0;
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(!vis[tmp]&&v!=dad) clear(v,now);
        }
    }
    void did(int now,int dad){
        dfs1(now,dad);
        ans[now]+=sum-color[V[now]]+size[now];
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(!vis[tmp]&&v!=dad){
                cnt[V[now]]++;
                sum-=size[v];
                color[V[now]]-=size[v];
                change(v,now,-1);
                cnt[V[now]]--;
                much=size[now]-size[v];
                dfs2(v,now);
                cnt[V[now]]++;
                sum+=size[v];
                color[V[now]]+=size[v];
                change(v,now,1);
                cnt[V[now]]--;
            }
        }
        sum=0;num=0;
        clear(now,dad);
    }
    void solve(int now,int dad){
        did(now,dad);
        for(int tmp=head[now];tmp;tmp=nxt[tmp]){
            int v=point[tmp];
            if(!vis[tmp]&&v!=dad){
                vis[tmp]=true;
                vis[tmp^1]=true;
                total=size[v];
                record=inf;
                findroot(v,now);
                solve(root,0);
            }
        }
    }
    int main(){
        tot=1;
        scanf("%d",&n);
        for(int i=1;i<=n;i++) scanf("%d",&V[i]);
        for(int i=1;i<n;i++){
            scanf("%d%d",&ui,&vi);
            addedge(ui,vi);
        }
        record=inf;
        total=n;
        findroot(1,0);
        solve(root,0);
        for(int i=1;i<=n;i++) printf("%lld\n",ans[i]);
        return 0;
    }