AGC058F

· · 题解

被震撼到了。非常好人类智慧题,使我大小为 -1 的点疯狂旋转。

思路

注意到,如果将 \dfrac{1}{n} 换成 \dfrac{1}{n-1},则首先 n=1 时答案为 1,根据数学归纳法可以得到 n 为任意值都是 1。而这个结论是基于边数为 n-1 的。这个的组合意义为:对于每一部分,随机删除一条边并分成两个部分分别删除,最终将整个图删成 n 个独立点的概率。

可是这个题不是 \dfrac{1}{n-1}。我们尝试加点东西,显然这个图是不能随便乱加东西的,这会导致很多性质变得不同。

既然我们改变不了数量,那我们就让它从删边改成删点不就是 \dfrac{1}{n} 的系数了吗!于是,我们可以给每条边加一个点(以后称为边点),这样就是每次删除一个点了。于是组合意义变成了每次随机删除一个点,将树变成两个部分分别删除,最终将整棵树删完的概率(显然也是 1)。可惜的是,点数变成了 2\times n-1,点数变化了,这不是我们想要的结果。需要注意的是,我们钦定了分别删除两个部分,也就是钦定了顺序,所以原来树上的点和新加的点的删除顺序不会影响答案。

这个时候厉害的来了,我们给每个边点连 -1 个点,这样点数就是 n 了。如果不好理解,可以看作在模 P 意义下连 P-1 个点。因为它的组合意义所以可以看作是等价的。

于是我们解决了 \dfrac{1}{n} 的问题。现在可以将这个题看作概率问题了。但是我们会在删除边点的同时删除原来树上的点,这个很让人头疼,因为我们如果删除了原树上的边,但此时周围的边点没有删,那么就相当于在原问题上删除了一点,这是不合理的。

于是,我们需要钦定所有边点的删除顺序先于周围的所有点,这样每次删除的时候所有分出来的树的大小都与原来树上对应部分的大小模 P 意义下相等。那么原问题就变成了随机删除一个点,最后所有边点删除先于周围所有点的概率。

我们给这个顺序建立拓补图,发现有的限制是儿子先于父亲,有的限制是父亲先于儿子,很不好算。考虑容斥,钦定每个儿子先于父亲的限制变成父亲先于儿子或者不限制,这下所有边的方向就相同了。

f_{i,j} 表示点 i 所属子树钦定有边的外向树大小为 j 时满足要求的概率,若当前点是原树上的点,则每个儿子必然都是边点,要么选择删掉,方案为 f_{u,i}=f_{u,i}\times \sum_{j=1}^{siz_v} f_{v,j},要么选择变成上到下,方案为 f_{u,i+j}=f_{u,i+j}-f_{u,i}\times f_{v,j}。因为你钦定了一条边反向,所以要乘上容斥系数 -1。如果是边点,则每条限制都是向下的,则 f_{u,i+j}=f_{u,i+j}+f_{u,i}\times f_{v,j}(注意此时并没有反向一个限制所以不用乘容斥系数 -1)。

最后,每个点在外向树内必然是第一个被删除的,即 f_{u,i}=f_{u,i}\times \dfrac{1}{i}

可是每个边点还有 P-1 个额外点,所以边点的转移是困难的。需要注意的是,当前边点加上这些点 siz0,且这些的 dp 值都是 f_{u,1}=1。而每个边点只会连接一个原树上的儿子。所以这个边点最终的值就是对于那个唯一原树上的儿子,f_{u,i}=f_{v,i}\times \dfrac{1}{i}

根据容斥,最终答案为 \sum_{i=1}^{n}f_{1,i}

使用树形背包的合并方式(即只卷到 siz 大小)是 O(n^2) 的。

代码

#include<bits/stdc++.h>
using namespace std;
const int NN=5004,P=998244353;
vector<int>g[NN];
int f[NN][NN],siz[NN],inv[NN];
void dfs(int u,int fa)
{
    f[u][1]=1;
    siz[u]=1;
    for(auto v:g[u])
    {
        if(v==fa)
            continue;
        dfs(v,u);
        for(int i=siz[u];i;i--)
        {
            int res=0;
            for(int j=1;j<=siz[v];j++)
            {
                f[u][i+j]=(f[u][i+j]-1ll*f[u][i]*f[v][j]%P*inv[j]%P+P)%P;
                res=(res+1ll*f[v][j]*inv[j]%P)%P;
            }
            f[u][i]=1ll*f[u][i]*res%P;
        }
        siz[u]+=siz[v];
    }
    for(int i=1;i<=siz[u];i++)
        f[u][i]=1ll*f[u][i]*inv[i]%P;
}
int main()
{
    int n;
    scanf("%d",&n);
    inv[1]=1;
    for(int i=2;i<=n;i++)
        inv[i]=1ll*(P-P/i)*inv[P%i]%P;
    for(int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
    int res=0;
    for(int i=1;i<=n;i++)
        res=(res+f[1][i])%P;
    printf("%d",res);
    return 0;
}