wucstdio
2019-11-18 08:54:03
考场这道题的思考时间比前两道题都短……
果然是学的太多思维僵化了。
首先,题目中给的式子可以等价为枚举一个点,然后查看这个点会作为多少棵子树的重心。
接下来我们的问题就是对每一个点
首先拎出来一个重心当做树根,然后 dfs 一遍。
如果
我们设
化简可以得到
我们用一个树状数组记录这个点父节点的那颗子树中
但是还有一个问题。这样一来我们统计的答案有可能会包含在
为了解决这个问题,我们可以再 dfs 一遍,这一次我们用线段树合并求出每一个点的子树中
最后还剩下
接下来分情况讨论。如果我们是在最大的子树里面切割,就需要满足
即
否则我们需要满足
即
我们可以在线段树合并的过程中顺便求出这个答案。
总时间复杂度是
下面是代码:
#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
#define lson tree[x].child[0]
#define rson tree[x].child[1]
#define mid (l+r)/2
using namespace std;
struct Edge
{
int to;
int nxt;
}e[600005];
struct Tree
{
int child[2];
int sum;
}tree[10000005];
int n,m,edgenum,tot,head[300005],pa[300005],size[300005],s[300005],root[300005];
ll ans;
bool isroot[300005];
int read()
{
char c=(char)getchar();
while(c>'9'||c<'0')c=(char)getchar();
int t=0;
while(c>='0'&&c<='9')
{
t=t*10+c-'0';
c=(char)getchar();
}
return t;
}
void add(int u,int v)
{
e[++edgenum].to=v;
e[edgenum].nxt=head[u];
head[u]=edgenum;
}
void dfs_pre(int node)//预处理,找出重心
{
size[node]=1;
for(int hd=head[node];hd;hd=e[hd].nxt)
{
int to=e[hd].to;
if(to==pa[node])continue;
pa[to]=node;
dfs_pre(to);
size[node]+=size[to];
if(size[to]>n/2)isroot[node]=0;
}
if(n-size[node]>n/2)isroot[node]=0;
}
void dfs1(int node)//求出所有的size
{
size[node]=1;
for(int hd=head[node];hd;hd=e[hd].nxt)
{
int to=e[hd].to;
if(to==pa[node])continue;
pa[to]=node;
dfs1(to);
size[node]+=size[to];
}
}
void add(int p)
{
while(p<=n)
{
s[p]++;
p+=p^(p&(p-1));
}
}
void dec(int p)
{
while(p<=n)
{
s[p]--;
p+=p^(p&(p-1));
}
}
int sum(int p)
{
int ans=0;
while(p)
{
ans+=s[p];
p&=p-1;
}
return ans;
}
int sum(int l,int r)
{
if(l>r)return 0;
return sum(r)-sum(l-1);
}//以上是树状数组
int merge(int x,int y)
{
if(!x||!y)return x+y;
tree[x].sum+=tree[y].sum;
lson=merge(lson,tree[y].child[0]);
rson=merge(rson,tree[y].child[1]);
return x;
}
void add(int x,int l,int r,int p)
{
tree[x].sum++;
if(l==r)return;
if(p<=mid)
{
if(!lson)lson=++tot;
add(lson,l,mid,p);
}
else
{
if(!rson)rson=++tot;
add(rson,mid+1,r,p);
}
}
//void debug(int x,int l,int r)
//{
// if(l==r)
// {
// printf("%lld ",tree[x].sum);
// return;
// }
// debug(lson,l,mid);
// debug(rson,mid+1,r);
//}
ll sum(int x,int l,int r,int from,int to)
{
if(from>to)return 0;
if(!x)return 0;
if(l>=from&&r<=to)return tree[x].sum;
ll ans=0;
if(from<=mid)ans+=sum(lson,l,mid,from,to);
if(to>mid)ans+=sum(rson,mid+1,r,from,to);
return ans;
}//以上是线段树合并
void dfs2(int node)
{
// printf("%d:\n",node);
int max1=0,max2=0;
for(int hd=head[node];hd;hd=e[hd].nxt)
{
int to=e[hd].to;
if(to==pa[node])continue;
if(size[to]>max1)max2=max1,max1=size[to];
else if(size[to]>max2)max2=size[to];
}
// printf("max1=%d,max2=%d\n",max1,max2);
if(isroot[node]&&pa[node])
{
if(n-size[node]>max1)max2=max1,max1=n-size[node];
else if(n-size[node]>max2)max2=n-size[node];
if(n-size[node]==max1)ans+=1ll*node*sum(1,n-2*max2);
else ans+=1ll*node*sum(1,n-2*max1);
}
for(int hd=head[node];hd;hd=e[hd].nxt)
{
int to=e[hd].to;
if(to==pa[node])continue;
pa[to]=node;
dec(size[to]);
add(n-size[to]);
dfs2(to);
// printf("to=%d,root=%d,",to,root[to]);
// debug(root[to],1,n);
// printf("\n");
if(isroot[node])
{
// printf("%d->%d,range=",node,to);
// if(size[to]==max1)printf("%d %d,res=%lld\n",1,n-2*max2,sum(root[to],1,n,1,n-2*max2));
// else printf("%d %d,res=%lld\n",1,n-2*max1,sum(root[to],1,n,1,n-2*max1));
if(size[to]==max1)ans+=1ll*node*sum(root[to],1,n,1,n-2*max2);
else ans+=1ll*node*sum(root[to],1,n,1,n-2*max1);
}
root[node]=merge(root[node],root[to]);
add(size[to]);
dec(n-size[to]);
}
if(!isroot[node])
{
// printf("node=%d:tot=%lld-%lld\n",node,sum(max(1,n-2*size[node]),min(n,n-2*max1)),sum(root[node],1,n,max(1,n-2*size[node]),min(n,n-2*max1)));
ans+=1ll*node*sum(max(1,n-2*size[node]),min(n,n-2*max1));
ans-=1ll*node*sum(root[node],1,n,max(1,n-2*size[node]),min(n,n-2*max1));
}
if(isroot[node]&&pa[node])
{
if(n-size[node]==max1)ans-=1ll*node*sum(root[node],1,n,1,n-2*max2);
else ans-=1ll*node*sum(root[node],1,n,1,n-2*max1);
}
add(root[node]?root[node]:root[node]=++tot,1,n,size[node]);
// printf("node=%d:\n",node);
// printf("tree:");
// debug(root[node],1,n);
// printf("\n");
// printf("array:");
// for(int i=1;i<=n;i++)printf("%lld ",sum(i)-sum(i-1));
// printf("\n");
}
int t;
int main()
{
// freopen("centroid.in","r",stdin);
// freopen("centroid.out","w",stdout);
t=read();
while(t--)
{
n=read();
ans=edgenum=0;
for(int i=1;i<=tot;i++)tree[i].sum=tree[i].child[0]=tree[i].child[1]=0;
for(int i=1;i<=n;i++)s[i]=head[i]=pa[i]=size[i]=root[i]=0,isroot[i]=1;
tot=0;
for(int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v);
add(v,u);
}
dfs_pre(1);
for(int i=1;i<=n;i++)
{
if(isroot[i])
{
pa[i]=0;
dfs1(i);
for(int j=1;j<=n;j++)
if(j!=i)add(size[j]);
dfs2(i);
break;
}
}
printf("%lld\n",ans);
}
return 0;
}