题解 P5205 【【模板】多项式开根】

· · 题解

看到新模板,过来做一下。

前置知识:多项式求逆 + NTT

H^2(x)\equiv F(x)(mod\ x^{\lceil\frac n2 \rceil})

G(x)\equiv H(x)(mod\ x^{\lceil\frac n2 \rceil}) G(x)-H(x)\equiv 0(mod\ x^{\lceil\frac n2 \rceil}) (G(x)-H(x))^2\equiv 0(mod\ x^{\lceil\frac n2 \rceil}) G^2(x)-2H(x)*G(x)+H^2(x)\equiv 0(mod\ x^n) F(x)-2H(x)*G(x)+H^2(x)\equiv 0(mod\ x^n) G(x)=\Large \frac {F(x)+H^2(x)}{2H(x)}

对上述式子进行多项式求逆 + NTT 即可

Code\ Below:
#include <bits/stdc++.h>
using namespace std;
const int maxn=400000+10;
const int mod=998244353;
const int inv2=499122177;
int n,f[maxn],g[maxn],A[maxn],B[maxn],C[maxn],D[maxn],r[maxn];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}

int fpow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=1ll*a*a%mod)
        if(b&1) ret=1ll*ret*a%mod;
    return ret;
}

void NTT(int *f,int n,int op){
    for(int i=0;i<n;i++)
        if(i<r[i]) swap(f[i],f[r[i]]);
    int buf,tmp,x,y;
    for(int len=1;len<n;len<<=1){
        tmp=fpow(3,(mod-1)/(len<<1));
        if(op==-1) tmp=fpow(tmp,mod-2);
        for(int i=0;i<n;i+=len<<1){
            buf=1;
            for(int j=0;j<len;j++){
                x=f[i+j];y=1ll*buf*f[i+j+len]%mod;
                f[i+j]=(x+y)%mod;f[i+j+len]=(x-y+mod)%mod;
                buf=1ll*buf*tmp%mod;
            }
        }
    }
    if(op==1) return ;
    int inv=fpow(n,mod-2);
    for(int i=0;i<n;i++) f[i]=1ll*f[i]*inv%mod;
}

void Inv(int *a,int *b,int n){
    b[0]=fpow(a[0],mod-2);
    int len,lim;
    for(len=1;len<(n<<1);len<<=1){
        lim=len<<1;
        for(int i=0;i<len;i++) A[i]=a[i],B[i]=b[i];
        for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)?len:0);
        NTT(A,lim,1);NTT(B,lim,1);
        for(int i=0;i<lim;i++) b[i]=((2ll-1ll*A[i]*B[i]%mod)*B[i]%mod+mod)%mod;
        NTT(b,lim,-1);
        for(int i=len;i<lim;i++) b[i]=0;
    }
    for(int i=0;i<len;i++) A[i]=B[i]=0;
    for(int i=n;i<len;i++) b[i]=0;
}

void Sqrt(int *a,int *b,int n){
    b[0]=1;
    int *A=C,*B=D,len,lim;
    for(len=1;len<(n<<1);len<<=1){
        lim=len<<1;
        for(int i=0;i<len;i++) A[i]=a[i];
        Inv(b,B,lim>>1);
        for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)?len:0);
        NTT(A,lim,1);NTT(B,lim,1);
        for(int i=0;i<lim;i++) A[i]=1ll*A[i]*B[i]%mod;
        NTT(A,lim,-1);
        for(int i=0;i<len;i++) b[i]=1ll*(b[i]+A[i])%mod*inv2%mod;
        for(int i=len;i<lim;i++) b[i]=0;
    }
    for(int i=0;i<len;i++) A[i]=B[i]=0;
    for(int i=n;i<len;i++) b[i]=0;
}

int main()
{
    n=read();
    for(int i=0;i<n;i++) f[i]=read();
    Sqrt(f,g,n);
    for(int i=0;i<n;i++) printf("%d ",g[i]);
    printf("\n");
    return 0;
}