题解 P5050 【【模板】多项式多点求值】

mrsrz

2018-12-26 18:44:16

题解

我们将要求值的点均分成两份,构造多项式P_0(x)=\prod\limits_{i=1}^{\lfloor\frac n 2\rfloor}(x-x_i)P_1(x)=\prod\limits_{i=\lfloor\frac n 2\rfloor+1}^{n}(x-x_i)

显然,对于\forall i\in[1,\lfloor\frac n 2 \rfloor],有P_0(x_i)=0P_1同理。

我们假设多项式D(x),R(x)满足:D(x)是一个n-\lfloor\frac n 2\rfloor次多项式,R(x)是一个次数不超过\lfloor\frac n 2\rfloor-1的多项式,且A(x)=P_0(x)D(x)+R(x)

那么对于\forall i\in[1,\lfloor\frac n 2 \rfloor],有A(x_i)=R(x_i)P_1同理可得。

由于R(x)的次数是A(x)的次数的一半,所以我们把一个规模为n的问题,用O(n\log n)的复杂度(多项式取模,详见多项式除法),转化为两个规模为\frac n 2的问题。

分治计算即可,由主定理得时间复杂度O(n\log^2 n)

求每次的P_0(x),P_1(x),可以开始用一次分治NTT预处理。此处时间复杂度也为O(n\log^2 n)

常数极大就是了QAQ。

Code:

#include<cstdio>
#include<algorithm>
typedef long long LL;
const int md=998244353,N=262145;
//Poly begin
int rev[N],lim,g[20][N],M,vv;
inline void upd(int&a){a+=a>>31&md;}
inline int pow(int a,int b){
    int ret=1;
    for(;b;b>>=1,a=(LL)a*a%md)if(b&1)ret=(LL)ret*a%md;return ret;
}
inline void init(int n){
    int l=-1;
    for(lim=1;lim<n;lim<<=1)++l;M=l+1;
    for(int i=1;i<lim;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<l);vv=pow(lim,md-2);
}
void NTT(int*a,int f){
    for(int i=1;i<lim;++i)if(i<rev[i])std::swap(a[i],a[rev[i]]);
    for(int i=0;i<M;++i){
        const int*G=g[i],c=1<<i;
        for(int j=0;j<lim;j+=c<<1)
        for(int k=0;k<c;++k){
            const int x=a[j+k],y=a[j+k+c]*(LL)G[k]%md;
            upd(a[j+k]+=y-md),upd(a[j+k+c]=x-y);
        }
    }
    if(!f){
        for(int i=0;i<lim;++i)a[i]=(LL)a[i]*vv%md;
        std::reverse(a+1,a+lim);
    }
}
void INV(const int*a,int*B,int n){
    if(n==1){
        *B=pow(*a,md-2);
        return;
    }
    INV(a,B,n+1>>1);
    init(n<<1);
    static int A[N];
    for(int i=0;i<n;++i)A[i]=a[i];
    for(int i=n;i<lim;++i)A[i]=0;
    NTT(A,1),NTT(B,1);
    for(int i=0;i<lim;++i)B[i]=B[i]*(2-(LL)A[i]*B[i]%md+md)%md;
    NTT(B,0);
    for(int i=n;i<lim;++i)B[i]=0;
}
void REV(int*A,int n){std::reverse(A,A+n+1);}
void MOD(const int*a,const int*b,int*R,int n,int m){
    static int A[N],B[N],D[N];
    for(int i=0;i<n<<1;++i)D[i]=0;
    for(int i=0;i<=m;++i)B[i]=b[i];
    REV(B,m);
    for(int i=n-m+1;i<n<<1;++i)B[i]=0;
    INV(B,D,n-m+1);
    init(n-m+1<<1);
    for(int i=0;i<=n-m+1;++i)A[i]=a[n-i];
    for(int i=n-m+2;i<lim;++i)A[i]=0;
    NTT(A,1),NTT(D,1);
    for(int i=0;i<lim;++i)D[i]=(LL)D[i]*A[i]%md;
    NTT(D,0);
    REV(D,n-m);
    init(n+1<<1);
    for(int i=n-m+1;i<lim;++i)D[i]=0;
    for(int i=0;i<lim;++i)A[i]=B[i]=0;
    for(int i=0;i<=m;++i)B[i]=b[i];
    for(int i=0;i<=n;++i)A[i]=a[i];
    NTT(B,1),NTT(D,1);
    for(int i=0;i<lim;++i)B[i]=(LL)B[i]*D[i]%md;
    NTT(B,0);
    for(int i=0;i<m;++i)upd(R[i]=A[i]-B[i]);
}
//Poly end
int n,m,A[N],a[N],*P[N],len[N];
void solve(int l,int r,int o){
    if(l==r){
        len[o]=1;
        P[o]=new int[2];
        upd(P[o][0]=-a[l]),P[o][1]=1;
        return;
    }
    const int mid=l+r>>1,L=o<<1,R=L|1;
    solve(l,mid,L),solve(mid+1,r,R);
    len[o]=len[L]+len[R];
    P[o]=new int[len[o]+1];
    init(len[o]+1<<1);
    static int A[N],B[N];
    for(int i=len[L];~i;--i)A[i]=P[L][i];
    for(int i=len[L]+1;i<lim;++i)A[i]=0;
    for(int i=len[R];~i;--i)B[i]=P[R][i];
    for(int i=len[R]+1;i<lim;++i)B[i]=0;
    NTT(A,1),NTT(B,1);
    for(int i=0;i<lim;++i)A[i]=(LL)A[i]*B[i]%md;
    NTT(A,0);
    for(int i=len[o];~i;--i)P[o][i]=A[i];
}
void solve(int l,int r,int o,const int*A){
    if(l==r){printf("%d\n",*A);return;}
    const int mid=l+r>>1,L=o<<1,R=L|1;
    int B[len[o]+2<<1];
    MOD(A,P[L],B,len[o]-1,len[L]);
    solve(l,mid,L,B);
    MOD(A,P[R],B,len[o]-1,len[R]);
    solve(mid+1,r,R,B);
}
int main(){
    for(int i=0;i<19;++i){
        int*G=g[i];
        G[0]=1;
        const int gi=G[1]=pow(3,(md-1)/(1<<i+1));
        for(int j=2;j<1<<i;++j)G[j]=(LL)G[j-1]*gi%md;
    }
    scanf("%d%d",&n,&m);if(!m)return 0;
    for(int i=0;i<=n;++i)scanf("%d",A+i);
    for(int i=1;i<=m;++i)scanf("%d",a+i);
    solve(1,m,1);
    if(n>=m)MOD(A,P[1],A,n,m);
    solve(1,m,1,A);
    return 0;
}