题解 P4721 【【模板】分治 FFT】

ljc1301

2019-01-14 09:59:44

题解

分治fft的模板题干嘛用多项式求逆呢

讲一下(我理解的、可以做这道题)分治fft。

其实就是cdq分治的思路。每次求一个区间的时候,要保证这个区间左边所有值对这个区间的贡献都算出来了。每次把区间分为左右两半,先算做半段,然后用ntt计算左半段对右半段的贡献,再算右半度。

举个例子g[1..3]=1, 1, 0,求f[0..3](别问我干嘛要拿这个算斐波那契数列)

刚开始,是这样的(用中括号代表要算的区间,竖线代表中间的位置)

f =[1 0|0 0]

先算左边

f =[1|0]0 0

左半边的长度为1,不往下递归。计算左区间对右区间的贡献。就是把1, 0和g的前两项0, 1做卷积,得到*, 1(星号代表我们不在意这个位置),再把得到的后半段加到这个区间的右半边。操作后:

f =[1|1]0 0

右半边的长度为1,不往下递归。这一步就好了,回到上一步。

f =[1 1|0 0]

计算左半边对右半边的贡献。把1, 1, 0, 0和g的前四项0, 1, 1, 0做卷积,得到*, *, 2, 1,再把得到的后半段加到这个区间的右半边。操作后:

f =[1 1|2 1]

现在开始计算右半段。

f = 1 1[2|1]

左半边的长度是1,不往下递归。计算左区间对右区间的贡献。就是把2, 0(注意这里是0)和g的(不是后)两项0, 1做卷积,得到*, 2,再把得到的后半段到这个区间的右半边。操作后:

f = 1 1[2|3]

右半边的长度为1,不往下递归。然后这个f数组就算好了。

代码(我是先填充成2的幂,这样好写):

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxlogn=18;
const int maxn=(1<<maxlogn)|1;
const int G0=15311432;
const int kcz=998244353;
int n,rev[maxn];
ll G[2][24],f[maxn],g[maxn],a[maxn],b[maxn];
void gcd(ll a,ll b,ll &x,ll &y)
{
    if(!b) x=1,y=0;
    else gcd(b,a%b,y,x),y-=x*(a/b);
}
inline ll inv(ll a)
{
    ll x,y;
    gcd(a,kcz,x,y);
    return (x+kcz)%kcz;
}
inline void calcrev(int logn)
{
    register int i;
    rev[0]=0;
    for(i=1;i<(1<<logn);i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(logn-1));
}
inline void FFT(ll *a,int logn,int flag)
{
    register int i,j,k,mid;
    register ll t1,t2,t;
    for(i=0;i<(1<<logn);i++)
        if(rev[i]<i)
            swap(a[rev[i]],a[i]);
    for(i=1;i<=logn;i++)
        for(mid=1<<(i-1),j=0;j<(1<<logn);j+=1<<i)
            for(k=0,t=1;k<mid;k++,t=t*G[flag][i]%kcz)
            {
                t1=a[j|k],t2=t*a[j|k|mid];
                a[j|k]=(t1+t2)%kcz,a[j|k|mid]=(t1-t2)%kcz;
            }
}
void solve(int l,int r,int logn)
{
    if(logn<=0) return;
    if(l>=n) return;
    int mid=(l+r)>>1,i;
    ll t=inv(r-l);
    solve(l,mid,logn-1); // 计算左区间
    calcrev(logn);
    memset(a+(r-l)/2,0,sizeof(ll)*(r-l)/2); // 拷贝左区间
    memcpy(a,f+l,sizeof(ll)*(r-l)/2); // 填充0
    memcpy(b,g,sizeof(ll)*(r-l)); // 拷贝g
    FFT(a,logn,0),FFT(b,logn,0); // 卷积
    for(i=0;i<r-l;i++) a[i]=a[i]*b[i]%kcz;
    FFT(a,logn,1);
    for(i=0;i<r-l;i++) a[i]=a[i]*t%kcz;
    for(i=(r-l)/2;i<r-l;i++)
        f[l+i]=(f[l+i]+a[i])%kcz; // 把卷积后的右半段的数加到f数组后半段
    // 可能你会注意到,这个卷积是(r-l)/2的长度卷一个r-l的长度,而我卷积时最终结果当作r-l的长度来存,这会不会有影响?注意到超出部分是(r-l)/2左右,根据fft的实现,超出部分是会重新从0开始填的,所以只会影响结果的前半段,与后半段无关
    solve(mid,r,logn-1); // 计算右区间
}
int main()
{
    int logn,i;
    G[1][23]=inv(G[0][23]=G0);
    for(i=22;i>=0;i--)
    {
        G[0][i]=G[0][i+1]*G[0][i+1]%kcz;
        G[1][i]=G[1][i+1]*G[1][i+1]%kcz;
    }
    scanf("%d",&n);
    for(logn=0;(1<<logn)<n;logn++);
    for(i=1;i<n;i++) scanf("%lld",&g[i]);
    for(i=0;i<n;i++) f[i]=!i;
    solve(0,1<<logn,logn);
    for(i=0;i<n;i++)
        printf("%lld ",(f[i]+kcz)%kcz);
    printf("\n");
    return 0;
}