题解 CF755G 【PolandBall and Many Other Balls】

Karry5307

2020-05-20 10:18:23

题解

题意

排成一行的 n 个球分成 k 个组,一组可以是一个球或者是两个相邻的球,允许有些球不在任何一组里面,求方案数。

你需要对 1\sim k 都求出答案。

\texttt{Data Range:}1\leq n\leq 10^9,1\leq k\leq 2^{15}

题解

倍增 FFT 基础题。(其实真的不难)

f_{n,k}n 个球分成 k 组的方案数,于是我们可以得到一个显然的 \texttt{dp}

f_{n,k}=f_{n-1,k}+f_{n-1,k-1}+f_{n-2,k-1}

也就是枚举最后一个球的分组情况。

然而这样子做是 O(nk) 转移的,明显无法通过。

接下来考虑另外一种脑洞大开的转移方式:

将两段长度为 ab 的球拼成长度为 a+b 的球,如果中间一段不合并的话就是

f_{a+b,k}=\sum\limits_{i=0}^{k}f_{a,i}f_{b,k-i}

如果中间合并的话就是

f_{a+b,k}=\sum\limits_{i=0}^{k-1}f_{a-1,i}f_{b-1,k-i-1}

汇总一下

f_{a+b,k}=\sum\limits_{i=0}^{k}f_{a,i}f_{b,k-i}+\sum\limits_{i=0}^{k-1}f_{a-1,i}f_{b-1,k-i-1}

套路地设 F_n(x)=\sum\limits_{k=0}^{\infty}f_{n,k}x^k,那么可以重写之前的式子:

F_{a+b}(x)=F_a(x)F_b(x)+xF_{a-1}(x)F_{b-1}(x)

F_n(x)=F_{n-1}(x)+xF_{n-1}(x)+xF_{n-2}(x)

然后是个小 \texttt{trick}

\begin{cases}F_{2n}(x)=F^2_n(x)+xF^2_{n-1}(x)\\F_{2n-1}(x)=F_n(x)F_{n-1}(x)+xF_{n-1}(x)F_{n-2}(x)\\F_{n-2}(x)=F_{n-1}^2(x)+xF_{n-2}^2(x)\end{cases}

然后注意到我们可以从 (F_{n-2}(x),F_{n-1}(x),F_{n}(x)) 转移到 (F_{2n-2}(x),F_{2n-1}(x),F_{2n}(x)),再加上线性转移,就很符合倍增 FFT 的基本条件了。

于是高高兴兴转移就得了。

代码中 \texttt{setBit}f_n\to f_{n+1},然后 \texttt{shift}f_n\to f_{2n},就跟位运算的设置最后一位和左移一样。

时间复杂度 O(n\log^2n)。(倍增一个 \log,FFT 一个 \log

到时候来写快一点的做法。

代码

#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=262151,MOD=998244353,G=3,INVG=332748118;
ll cnt,kk,limit;
ll f[3][MAXN],g[5][MAXN],rev[MAXN];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
}
inline ll qpow(ll base,ll exponent)
{
    li res=1;
    while(exponent)
    {
        if(exponent&1)
        {
            res=1ll*res*base%MOD;
        }
        base=1ll*base*base%MOD,exponent>>=1;
    }
    return res;
}
inline void NTT(ll *cp,ll cnt,ll inv)
{
    ll cur=0,res=0,omg=0;
    for(register int i=0;i<cnt;i++)
    {
        if(i<rev[i])
        {
            swap(cp[i],cp[rev[i]]);
        }
    }
    for(register int i=2;i<=cnt;i<<=1)
    {
        cur=i>>1,res=qpow(inv==1?G:INVG,(MOD-1)/i);
        for(register ll *p=cp;p!=cp+cnt;p+=i)
        {
            omg=1;
            for(register int j=0;j<cur;j++)
            {
                ll t=1ll*omg*p[j+cur]%MOD,t2=p[j];
                p[j+cur]=(t2-t+MOD)%MOD,p[j]=(t2+t)%MOD;
                omg=1ll*omg*res%MOD;
            }
        }
    }
    if(inv==-1)
    {
        ll invl=qpow(cnt,MOD-2);
        for(register int i=0;i<cnt;i++)
        {
            cp[i]=1ll*cp[i]*invl%MOD;
        }
    }
}
inline void shift(ll fd)
{
    ll cnt=1,limit=-1;
    while(cnt<=(fd<<1))
    {
        cnt<<=1,limit++;
    }
    for(register int i=0;i<cnt;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        g[0][i]=g[1][i]=g[2][i]=g[3][i]=g[4][i]=0;
    }
    NTT(f[0],cnt,1),NTT(f[1],cnt,1),NTT(f[2],cnt,1);
    for(register int i=0;i<cnt;i++)
    {
        g[0][i]=(li)f[0][i]*f[0][i]%MOD,g[1][i]=(li)f[1][i]*f[1][i]%MOD;
        g[2][i]=(li)f[2][i]*f[2][i]%MOD,g[3][i]=(li)f[0][i]*f[1][i]%MOD;
        g[4][i]=(li)f[1][i]*f[2][i]%MOD,f[0][i]=f[1][i]=f[2][i]=0;
    }
    NTT(g[0],cnt,-1),NTT(g[1],cnt,-1),NTT(g[2],cnt,-1);
    NTT(g[3],cnt,-1),NTT(g[4],cnt,-1);
    for(register int i=0;i<cnt;i++)
    {
        f[0][i]=(g[0][i]+(i?g[1][i-1]:0))%MOD;
        f[1][i]=(g[3][i]+(i?g[4][i-1]:0))%MOD;
        f[2][i]=(g[1][i]+(i?g[2][i-1]:0))%MOD;
    }
    for(register int i=0;i<cnt;i++)
    {
        g[0][i]=g[1][i]=g[2][i]=g[3][i]=g[4][i]=0;
    }
    for(register int i=fd;i<cnt;i++)
    {
        f[0][i]=f[1][i]=f[2][i]=0;
    }
}
inline void setBit(ll fd)
{
    for(register int i=0;i<fd;i++)
    {
        f[2][i]=f[1][i],f[1][i]=f[0][i],f[0][i]=0;
    }
    f[0][0]=1;
    for(register int i=1;i<fd;i++)
    {
        f[0][i]=((f[1][i-1]+f[2][i-1])%MOD+f[1][i])%MOD;
    }
}
int main()
{
    cnt=read(),kk=read(),f[0][0]=f[1][0]=f[0][1]=1,limit=log2(cnt);
    for(register int i=limit-1;i>=0;i--)
    {
        shift(kk+10);
        if(cnt&(1<<i))
        {
            setBit(kk+10);
        }
    }
    for(register int i=1;i<=kk;i++)
    {
        printf("%d ",f[0][i]);
    }
}