题解 P5900 【无标号无根树计数】

Karry5307

2020-03-26 10:08:57

题解

题意

n 个点的无标号无根树数量。

\texttt{Data Range:}1\leq n\leq 2\times 10^5

题解

可能是你谷最详细阐明 \texttt{Euler} 变换的题解了。

首先考虑有根树计数,设 f_nn 个点的无标号有根树数量,那么我们可以钦定某一个点为根,剩下的子树的大小加起来必须为 n-1,会发现是个背包。于是考虑使用生成函数优化一下:

f_n=[x^{n-1}]\prod\limits_{i=1}^{\infty}\left(\frac{1}{1-x^i}\right)^{f_i}

现在我们来定义 F(x)\texttt{Euler} 变换为

\mathcal{E}(F(x))=\prod\limits_{i=1}^{\infty}\left(\frac{1}{1-x^i}\right)^{f_i}

这里给出另外一种组合解释,如果 F(x)n 个点满足某种性质的无标号连通图数量,那么 \mathcal{E}(F(x))n 个点满足这种性质的无标号图的数量,这个可以像上面那样,枚举连通块的个数做背包得到,类似于有标号图计数中的 \exp

但是这个定义式是可以优化的,如果你做过付公主的背包这个题目的话应该对这种优化比较熟悉。我们设 G(x)=\mathcal{E}(F(x)),那么对定义式等式两端取对数,有

\ln G(x)=-\sum\limits_{i=1}^{\infty}f_i\ln(1-x^i)

两边求导,得到

\frac{G^\prime(x)}{G(x)}=\sum\limits_{i=1}^{\infty}f_i\frac{ix^{i-1}}{1-x^i}

做个级数展开

\frac{G^\prime(x)}{G(x)}=\sum\limits_{i=1}^{\infty}f_i\sum\limits_{j=0}^{\infty}ix^{(j+1)i-1}

积分回来

\ln G(x)=\sum\limits_{i=1}^{\infty}f_i\sum\limits_{j=1}^{\infty}\frac{x^{ij}}{j}

交换求和顺序,可以写成这个样子

\ln G(x)=\sum\limits_{j=1}^{\infty}\frac{1}{j}\sum\limits_{i=1}^{\infty}f_i(x^{j})^i

注意到右边是个点值,所以

\ln G(x)=\sum\limits_{j=1}\frac{F(x^j)}{j}

于是有

\mathcal{E}(F(x))=\exp\left(\sum\limits_{j=1}\frac{F(x^j)}{j}\right)

回到无标号有根数的计数上面来,所以我们用生成函数表示可以得到以下式子

F(x)=x\cdot\mathcal{E}(F(x))

这个可以牛迭求出,尽管是 O(n\log n) 的,但是常数非常大,实际上可能比不过 O(n\log^2n) 的分治 \texttt{FFT}

考虑同时求导并且乘上 x,也就是使用 \vartheta 算子,可以得到(计算过程给读者留做练习

\vartheta F(x)=F(x)+F(x)\left(\sum\limits_{i=1}^{\infty}x^iF^\prime(x^i)\right)

G(x)=\sum\limits_{i=1}^{\infty}x^iF^\prime(x^i),他的系数可以通过找规律目瞪口呆法发现,有

g_n=\sum\limits_{d\mid n}df_d

于是有

nf_n=f_n+\sum\limits_{i}f_{i}g_{n-i}

解出来得到

f_n=\frac{1}{n-1}\sum\limits_{i}f_{i}g_{n-i}

这个可以分治 \texttt{FFT} 解决,虽说是 O(n\log^2n) 的,但是常数很小。

接下来考虑把根的影响去掉,这个的话可以考虑枚举重心,用有根树的答案减去根不是重心的答案。

由于当树的大小是偶数的时候可能存在两个重心,所以要分类讨论一下。

n 为奇数的时候我们可以不用想那么多,由于已经钦定了根不是重心,那么必然有一棵子树使得它的大小大于 \lfloor\frac{n}{2}\rfloor,设为 k,枚举一下这个 k,答案为

f_n-\sum\limits_{k=\lfloor\frac{n}{2}\rfloor+1}f_kf_{n-k}

右边的东西可以随手卷一下。

n 为偶数的时候有可能存在两个重心,且其中一个是根,枚举并判掉。

额外减掉的方案为

\binom{f_{\frac{n}{2}}}{2}

至此,这个题目就可以用 O(n\log n) 或者 O(n\log^2n) 的时间复杂度内解决。

代码

#include<bits/stdc++.h>
#pragma GCC optimize("Ofast,unroll-loops")
using namespace std;
typedef int ll;
typedef long long int li;
typedef unsigned long long int ull;
const ll MAXN=524291,MOD=998244353,G=3,INVG=332748118;
ll n,cnt,limit; 
ll f[MAXN],g[MAXN],tmp[MAXN],tmp2[MAXN],tmp3[MAXN],rev[MAXN<<1],omgs[MAXN<<1];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
} 
inline ll qpow(ll base,ll exponent)
{
    li res=1;
    while(exponent)
    {
        if(exponent&1)
        {
            res=1ll*res*base%MOD;
        }
        base=1ll*base*base%MOD,exponent>>=1;
    }
    return res;
}
inline void setupOmg(ll cnt)
{
    ll limit=log2(cnt)-1,omg;
    omg=qpow(G,(MOD-1)>>(limit+1)),omgs[cnt>>1]=1;
    for(register int i=(cnt>>1|1);i!=cnt;i++)
    {
        omgs[i]=(li)omgs[i-1]*omg%MOD;
    }
    for(register int i=(cnt>>1)-1;i;i--)
    {
        omgs[i]=omgs[i<<1]; 
    }
}
inline void NTT(ll *cp,ll cnt,ll inv)
{
    static ull tcp[MAXN];
    register ll cur=0,x,shift=log2(cnt)-__builtin_ctz(cnt);
    if(inv==-1)
    {
        reverse(cp+1,cp+cnt);
    }
    for(register int i=0;i<cnt;i++)
    {
        tcp[rev[i]>>shift]=cp[i];
    }
    for(register int i=2;i<=cnt;i<<=1)
    {
        cur=i>>1;
        for(register int j=0;j<cnt;j+=i)
        {
            for(register int k=0;k<cur;k++)
            {
                x=tcp[j|k|cur]*omgs[k|cur]%MOD;
                tcp[j|k|cur]=tcp[j|k]+MOD-x,tcp[j|k]+=x;
            }
        }
    }
    for(register int i=0;i<cnt;i++)
    {
        cp[i]=tcp[i]%MOD;
    }
    if(inv==1)
    {
        return;
    }
    x=MOD-(MOD-1)/cnt;
    for(register int i=0;i<cnt;i++)
    {
        cp[i]=(li)cp[i]*x%MOD;
    }
}
inline void conv(ll cnt,ll *f,ll *g,ll *res)
{
    static ll tmpf[MAXN],tmpg[MAXN];
    for(register int i=0;i<cnt;i++)
    {
        tmpf[i]=f[i],tmpg[i]=g[i],res[i]=0;
    }
    ll limit=log2(cnt)-1;
    for(register int i=0;i<cnt;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
    }
    NTT(tmpf,cnt,1),NTT(tmpg,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        res[i]=(li)tmpf[i]*tmpg[i]%MOD;
        tmpf[i]=tmpg[i]=0;
    }
    NTT(res,cnt,-1);
}
inline void calc(ll l,ll r)
{
    ll mid=(l+r)>>1;
    if(l==1)
    {
        for(register int i=l;i<=mid;i++)
        {
            tmp[i-l]=f[i],tmp2[i-l]=g[i];
        }
        ll cnt=1;
        while(cnt<=((r-l)<<1))
        {
            cnt<<=1;
        }
        conv(cnt,tmp,tmp2,tmp3);
        for(register int i=mid+1;i<=r;i++)
        {
            f[i]=(f[i]+tmp3[i-l-1])%MOD;
        }
        for(register int i=0;i<cnt;i++)
        {
            tmp[i]=tmp2[i]=tmp3[i]=0;
        }
        return;
    }
    for(register int i=l;i<=mid;i++)
    {
        tmp[i-l]=f[i];
    }
    for(register int i=1;i<=r-l;i++)
    {
        tmp2[i-1]=g[i];
    }
    ll cnt=1;
    while(cnt<=((r+mid-l-l+1)<<1))
    {
        cnt<<=1;
    }
    conv(cnt,tmp,tmp2,tmp3);
    for(register int i=mid+1;i<=r;i++)
    {
        f[i]=(f[i]+tmp3[i-l-1])%MOD;
    }
    for(register int i=0;i<cnt;i++)
    {
        tmp[i]=tmp2[i]=tmp3[i]=0;
    }
    for(register int i=l;i<=mid;i++)
    {
        tmp[i-l]=g[i];
    }
    for(register int i=1;i<=r-l;i++)
    {
        tmp2[i-1]=f[i];
    }
    conv(cnt,tmp,tmp2,tmp3);
    for(register int i=mid+1;i<=r;i++)
    {
        f[i]=(f[i]+tmp3[i-l-1])%MOD;
    }
    for(register int i=0;i<cnt;i++)
    {
        tmp[i]=tmp2[i]=tmp3[i]=0;
    }
}
inline void cdqFFT(ll l,ll r)
{
    if(l==r)
    {
        f[l]=l!=1?((li)f[l]*qpow(l-1,MOD-2)%MOD):1;
        for(register int i=l;i<=n;i+=l)
        {
            g[i]=(g[i]+(li)f[l]*l%MOD)%MOD;
        }
        return;
    }
    ll mid=(l+r)>>1;
    cdqFFT(l,mid),calc(l,r),cdqFFT(mid+1,r);
}
int main()
{
    n=read()+1,cnt=1;
    while(cnt<=n)
    {
        cnt<<=1;
    }
    setupOmg(cnt<<1),cdqFFT(1,cnt),cnt=1,limit=-1;
    while(cnt<=(n<<1))
    {
        cnt<<=1,limit++;
    }
    for(register int i=0;i<cnt;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        tmp[i]=i<=n?f[i]:0;
    }
    NTT(tmp,cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        tmp[i]=(li)tmp[i]*tmp[i]%MOD;
    }
    NTT(tmp,cnt,-1);
    for(register int i=1;i<=n;i+=2)
    {
        g[i]=(f[i]-(li)tmp[i]*499122177%MOD+MOD)%MOD; 
    }
    for(register int i=2;i<=n;i+=2)
    {
        tmp[i]=(tmp[i]-(li)f[i>>1]*f[i>>1]%MOD+MOD)%MOD;
        g[i]=(f[i]-(li)tmp[i]*499122177%MOD+MOD)%MOD;
        g[i]=(g[i]-(li)f[i>>1]*(f[i>>1]-1)%MOD*499122177%MOD+MOD)%MOD;
    }
    printf("%d\n",g[n-1]);
}