题解 P5666 【树的重心【民间数据】】

wucstdio

2019-11-18 08:54:03

题解

考场这道题的思考时间比前两道题都短……

果然是学的太多思维僵化了。

首先,题目中给的式子可以等价为枚举一个点,然后查看这个点会作为多少棵子树的重心。

接下来我们的问题就是对每一个点 x ,求出这个点会被统计多少次,也就是求出有多少割掉一棵子树的方案,使得 x 成为重心。

首先拎出来一个重心当做树根,然后 dfs 一遍。

如果 x 不是重心,那么就需要在它父亲的那颗子树中切掉一棵子树:

我们设 x 这棵子树大小为 sx 的子节点中最大的子树大小为 m ,那么切掉的点 del 需要满足如下式子:

\begin{cases}n-s-size[del]\le \dfrac{n-size[del]}{2}\\m\le \dfrac{n-size[del]}{2}\end{cases}

化简可以得到

n-2s\le size[del]\le n-2m

我们用一个树状数组记录这个点父节点的那颗子树中 size=i 的子树数量。在 dfs 的过程中,每 dfs 到一个点 x 就让 c[size[x]]-1,c[n-size[x]]+1 ,也就是当我们割掉 fx 的边的时候切掉的子树大小从 size[x] 变成了 n-size[x] 。统计答案就是区间求和。

但是还有一个问题。这样一来我们统计的答案有可能会包含在 x 的子树中删掉一棵子树的情况,这样是不满足条件的。

为了解决这个问题,我们可以再 dfs 一遍,这一次我们用线段树合并求出每一个点的子树中 size=i 的子树数量。这里统计的是不合法的方案,从答案里面减去就行了。

最后还剩下 x 是重心的情况。由于重心只有最多两个,我们可以先以重心为根对整棵树进行dfs,统计出 x 这个点最大的子树 m_1 和次大的子树 m_2

接下来分情况讨论。如果我们是在最大的子树里面切割,就需要满足

m_2\le \dfrac{n-size[del]}{2}

size[del]\le n-2m_2

否则我们需要满足

m_1\le \dfrac{n-size[del]}{2}

size[del]\le n-2m_1

我们可以在线段树合并的过程中顺便求出这个答案。

总时间复杂度是 O(n\log n) ,在考场电脑上不写读入优化 3s 左右,写读入优化 2.5s 。

下面是代码:

#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
#define lson tree[x].child[0]
#define rson tree[x].child[1]
#define mid (l+r)/2
using namespace std;
struct Edge
{
    int to;
    int nxt;
}e[600005];
struct Tree
{
    int child[2];
    int sum;
}tree[10000005];
int n,m,edgenum,tot,head[300005],pa[300005],size[300005],s[300005],root[300005];
ll ans;
bool isroot[300005];
int read()
{
    char c=(char)getchar();
    while(c>'9'||c<'0')c=(char)getchar();
    int t=0;
    while(c>='0'&&c<='9')
    {
        t=t*10+c-'0';
        c=(char)getchar();
    }
    return t;
}
void add(int u,int v)
{
    e[++edgenum].to=v;
    e[edgenum].nxt=head[u];
    head[u]=edgenum;
}
void dfs_pre(int node)//预处理,找出重心
{
    size[node]=1;
    for(int hd=head[node];hd;hd=e[hd].nxt)
    {
        int to=e[hd].to;
        if(to==pa[node])continue;
        pa[to]=node;
        dfs_pre(to);
        size[node]+=size[to];
        if(size[to]>n/2)isroot[node]=0;
    }
    if(n-size[node]>n/2)isroot[node]=0;
}
void dfs1(int node)//求出所有的size
{
    size[node]=1;
    for(int hd=head[node];hd;hd=e[hd].nxt)
    {
        int to=e[hd].to;
        if(to==pa[node])continue;
        pa[to]=node;
        dfs1(to);
        size[node]+=size[to];
    }
}
void add(int p)
{
    while(p<=n)
    {
        s[p]++;
        p+=p^(p&(p-1));
    }
}
void dec(int p)
{
    while(p<=n)
    {
        s[p]--;
        p+=p^(p&(p-1));
    }
}
int sum(int p)
{
    int ans=0;
    while(p)
    {
        ans+=s[p];
        p&=p-1;
    }
    return ans;
}
int sum(int l,int r)
{
    if(l>r)return 0;
    return sum(r)-sum(l-1);
}//以上是树状数组
int merge(int x,int y)
{
    if(!x||!y)return x+y;
    tree[x].sum+=tree[y].sum;
    lson=merge(lson,tree[y].child[0]);
    rson=merge(rson,tree[y].child[1]);
    return x;
}
void add(int x,int l,int r,int p)
{
    tree[x].sum++;
    if(l==r)return;
    if(p<=mid)
    {
        if(!lson)lson=++tot;
        add(lson,l,mid,p);
    }
    else
    {
        if(!rson)rson=++tot;
        add(rson,mid+1,r,p);
    }
}
//void debug(int x,int l,int r)
//{
//  if(l==r)
//  {
//      printf("%lld ",tree[x].sum);
//      return;
//  }
//  debug(lson,l,mid);
//  debug(rson,mid+1,r);
//}
ll sum(int x,int l,int r,int from,int to)
{
    if(from>to)return 0;
    if(!x)return 0;
    if(l>=from&&r<=to)return tree[x].sum;
    ll ans=0;
    if(from<=mid)ans+=sum(lson,l,mid,from,to);
    if(to>mid)ans+=sum(rson,mid+1,r,from,to);
    return ans;
}//以上是线段树合并
void dfs2(int node)
{
//  printf("%d:\n",node);
    int max1=0,max2=0;
    for(int hd=head[node];hd;hd=e[hd].nxt)
    {
        int to=e[hd].to;
        if(to==pa[node])continue;
        if(size[to]>max1)max2=max1,max1=size[to];
        else if(size[to]>max2)max2=size[to];
    }
//  printf("max1=%d,max2=%d\n",max1,max2);
    if(isroot[node]&&pa[node])
    {
        if(n-size[node]>max1)max2=max1,max1=n-size[node];
        else if(n-size[node]>max2)max2=n-size[node];
        if(n-size[node]==max1)ans+=1ll*node*sum(1,n-2*max2);
        else ans+=1ll*node*sum(1,n-2*max1);
    }
    for(int hd=head[node];hd;hd=e[hd].nxt)
    {
        int to=e[hd].to;
        if(to==pa[node])continue;
        pa[to]=node;
        dec(size[to]);
        add(n-size[to]);
        dfs2(to);
//      printf("to=%d,root=%d,",to,root[to]);
//      debug(root[to],1,n);
//      printf("\n");
        if(isroot[node])
        {
//          printf("%d->%d,range=",node,to);
//          if(size[to]==max1)printf("%d %d,res=%lld\n",1,n-2*max2,sum(root[to],1,n,1,n-2*max2));
//          else printf("%d %d,res=%lld\n",1,n-2*max1,sum(root[to],1,n,1,n-2*max1));
            if(size[to]==max1)ans+=1ll*node*sum(root[to],1,n,1,n-2*max2);
            else ans+=1ll*node*sum(root[to],1,n,1,n-2*max1);
        }
        root[node]=merge(root[node],root[to]);
        add(size[to]);
        dec(n-size[to]);
    }
    if(!isroot[node])
    {
//      printf("node=%d:tot=%lld-%lld\n",node,sum(max(1,n-2*size[node]),min(n,n-2*max1)),sum(root[node],1,n,max(1,n-2*size[node]),min(n,n-2*max1)));
        ans+=1ll*node*sum(max(1,n-2*size[node]),min(n,n-2*max1));
        ans-=1ll*node*sum(root[node],1,n,max(1,n-2*size[node]),min(n,n-2*max1));
    }
    if(isroot[node]&&pa[node])
    {
        if(n-size[node]==max1)ans-=1ll*node*sum(root[node],1,n,1,n-2*max2);
        else ans-=1ll*node*sum(root[node],1,n,1,n-2*max1);
    }
    add(root[node]?root[node]:root[node]=++tot,1,n,size[node]);
//  printf("node=%d:\n",node);
//  printf("tree:");
//  debug(root[node],1,n);
//  printf("\n");
//  printf("array:");
//  for(int i=1;i<=n;i++)printf("%lld ",sum(i)-sum(i-1));
//  printf("\n");
}
int t;
int main()
{
//  freopen("centroid.in","r",stdin);
//  freopen("centroid.out","w",stdout);
    t=read();
    while(t--)
    {
        n=read();
        ans=edgenum=0;
        for(int i=1;i<=tot;i++)tree[i].sum=tree[i].child[0]=tree[i].child[1]=0;
        for(int i=1;i<=n;i++)s[i]=head[i]=pa[i]=size[i]=root[i]=0,isroot[i]=1;
        tot=0;
        for(int i=1;i<n;i++)
        {
            int u=read(),v=read();
            add(u,v);
            add(v,u);
        }
        dfs_pre(1);
        for(int i=1;i<=n;i++)
        {
            if(isroot[i])
            {
                pa[i]=0;
                dfs1(i);
                for(int j=1;j<=n;j++)
                  if(j!=i)add(size[j]);
                dfs2(i);
                break;
            }
        }
        printf("%lld\n",ans);
    }
    return 0;
}