题解 P2195 【HXY造公园】

· · 题解

Update 2020.08.25:修改了一些错别字与不通顺的地方

题目大意

给出一个 n 个点,m 条边组成的森林,有 q 组询问:

  1. 给出点 x,求点 x 所在的树的直径
  2. 给出点 x,y,要求将 x,y 所在的树之间连一条边并构成一棵新的树,满足这个新的树的直径最小

解题思路

首先,我们用树形DP(或bfs)求出每棵树的直径,并用并查集维护连通情况
维护 c 数组:对于每棵树的根节点 xc[x]= 该树的直径长度

接下来,对于每个询问 2(如果给出的两点在同一棵树内则忽略),利用并查集找出两棵树的根节点 x,y,并用并查集合并两棵树;合并后的树的直径则为 max\{ \lceil\frac{c[x]}{2}\rceil + \lceil\frac{c[y]}{2}\rceil +1,c[x],c[y] \},这里讲一下原因

要想直径最短,我们选择加边的点一定要在直径上,因为其他的点走到直径还要一段距离,从而增长了路径

那么直径就被选择的点分成了两段。因为我们要最小化较长的那一段,所以要让选择的点尽量靠近直径的中点。最后的答案就是 直径长度的一半向上取整

注意:后边的 c[x],c[y] 一定要加上,不然只有三十分(像我一样QwQ)
因为在合并后的树里,原来两棵树的直径还是存在的,所以新的树的直径至少是 max(c[x],c[y])

比如说下图:

c[x]=4,c[y]=0

那么,按 \lceil\frac{c[x]}{2}\rceil + \lceil\frac{c[y]}{2}\rceil +1 算出的直径应该是 3
而正确的直径该是 4,所以一定要考虑 c[x],c[y]

最后,对于每个询问 1,输出该树根节点对应的 c 数组值就可以了

AC代码

#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int Maxn=300000+10,inf=0x3f3f3f3f;
int d[Maxn],g[Maxn];
int f[Maxn],c[Maxn]; // 这里的c数组的含义不是跟上面对应的
bool vis[Maxn];
vector <int> e[Maxn];
int n,m,q,len;
inline int read()
{
    int s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0' && ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=getchar();
    return s*w;
}
int find(int x)
{
    if(f[x]==x)return x;
    return f[x]=find(f[x]);
}
void dfs(int x,int fa) // 树形DP求树的直径
{
    int m1=-1,m2=-1;
    for(int i=0;i<e[x].size();++i)
    {
        int y=e[x][i];
        if(y==fa)continue;
        dfs(y,x);
        int tmp=d[y]+1;
        d[x]=max(d[x],tmp);
        if(tmp>m1)m2=m1,m1=tmp;
        else if(tmp>m2)m2=tmp;
    }
    g[x]=max(max(0,m1+m2),max(m1,m2));
    len=max(len,g[x]);
}
void calc(int x) // 寻找树的直径
{
    len=0;
    dfs(x,0);
    c[x]=len;
}
int main()
{
//  freopen("in.txt","r",stdin);
    n=read(),m=read(),q=read();

    for(int i=1;i<=n;++i)
    f[i]=i;

    for(int i=1;i<=m;++i)
    {
        int x=read(),y=read();
        f[find(x)]=find(y);
        e[x].push_back(y);
        e[y].push_back(x);
    }

    for(int i=1;i<=n;++i)
    {
        if(f[i]!=i || vis[i])continue;
        calc(i);
        vis[i]=1;
    }

    while(q--)
    {
        int opt=read(),x=read();
        if(opt==1)
        {
            printf("%d\n",c[find(x)]);
            continue;
        }
        int y=read();
        x=find(x),y=find(y);
        if(x==y)continue;
        int tmp=((c[x]+1)>>1)+((c[y]+1)>>1)+1; //一个巧妙的向上取整的方法

        tmp=max(tmp,max(c[x],c[y]));

        f[find(x)]=find(y);
        c[find(x)]=tmp;
    }

    return 0;
}