题解 P5393 【下降幂多项式转普通多项式】

Rorschachindark

2020-07-20 15:46:34

题解

题目传送门

题目大意

给出一个下降幂多项式 $F(x)=\sum_{i=0}^{n-1} a_i x^{\underline{i}}$,求一个普通多项式 $G(x)$ 使得 $G(x)=F(x)$。

$n\le 2\times 10^5$

思路

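一个非常 $\texttt{naive}$ 的想法:由于 $k^{\underline{i}}=\frac{k!}{(k-i)!}$,有 $\sum_{k\ge 0}\frac{F(k)}{k!}x^k=e^x\sum_{i=0}^{n-1}a_ix^i$,所以我们直接给系数序列乘上 $e^x$ 就能转换成点值的 $\texttt{EGF}$(从中得到 $F(0),F(1),\dots,F(n)$),然后再用多项式快速插值插回去。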

同时,我们发现插值点 $0,1,\dots,n$ 是连续的,拉格朗日插值的分母 $\prod_{j\ne k}(k-j)=(-1)^{n-k}\,k!\,(n-k)!$ 可以直接用阶乘求出,于是,我们就不需要多项式多点求值和多项式取模了。

最终的时间复杂度为 $\Theta(n\log^2 n)$。

$\texttt{Code}$

#include <bits/stdc++.h>
using namespace std;

// Size of a container as a signed int, so index loops can compare against it
// without signed/unsigned mixing. The argument is parenthesised so that any
// expression (e.g. a conditional) can be passed.
#define SZ(x) ((int)(x).size())
// Loop-counter type. The `register` storage class was removed in C++17, so
// this is a plain int.
#define Int int
// Prime modulus; the helpers below expect their operands in [0, mod).
#define mod 998244353
#define ll long long
// Capacity used for the static arrays declared in this file.
#define MAXN 2000005

// Product a * b modulo mod; the 64-bit intermediate keeps it from overflowing.
int mul (int a,int b){
    const long long product = 1ll * a * b;
    return product % mod;
}

// Difference a - b modulo mod for operands already reduced modulo mod.
int dec (int a,int b){
    if (a >= b) return a - b;
    return a + mod - b;
}

// Sum a + b modulo mod for operands already reduced modulo mod.
int add (int a,int b){
    const int sum = a + b;
    return sum >= mod ? sum - mod : sum;
}

// Binary exponentiation: a^k modulo mod for k >= 0.
int qkpow (int a,int k){
    long long base = a,result = 1;
    while (k){
        if (k & 1) result = result * base % mod;
        base = base * base % mod;
        k >>= 1;
    }
    return (int)result;
}

// Modular inverse of x, computed as x^(mod - 2) (Fermat's little theorem; mod is prime).
int inv (int x){return qkpow (x,mod - 2);}

// A polynomial stored as its coefficient vector: index i holds the coefficient of x^i.
typedef vector <int> poly;

// Lookup tables filled once by init_ntt():
//   w   - twiddle factors; butterflies of half-length h read w[h .. 2h-1]
//   rev - bit-reversed index of every position of a 2^19-point transform
int w[MAXN],rev[MAXN];

// Precomputes the bit-reversal table `rev` and the twiddle table `w` for
// transforms of up to 2^19 points. Must run once before the first ntt() call.
void init_ntt (){
    const int log_lim = 19;
    const int lim = 1 << log_lim;
    const int half = lim >> 1;

    // rev[i] is i with its log_lim low bits mirrored, built from rev[i / 2].
    for (int i = 0;i < lim;++ i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (log_lim - 1));

    // Top level: w[half + k] = root^k, where root = 3^((mod - 1) / lim).
    const int root = qkpow (3,(mod - 1) / lim);
    w[half] = 1;
    for (int i = half + 1;i < lim;++ i) w[i] = mul (w[i - 1],root);

    // Every lower level copies every second entry of the level above it.
    for (int i = half - 1;i >= 1;-- i) w[i] = w[i << 1];
}

// In-place number-theoretic transform of the first `lim` entries of `a`.
//   a    - coefficient vector holding at least `lim` entries
//   lim  - transform length; assumed to be a power of two not exceeding 2^19,
//          the size prepared by init_ntt()
//   type - -1 selects the inverse transform; any other value (callers pass 1)
//          performs the forward transform
void ntt (poly &a,int lim,int type){
    static int buf[MAXN];

    // Scatter the input into bit-reversed order; rev[] is built for 19 bits,
    // so the surplus low bits are shifted away for shorter transforms.
    const int shift = 19 - __builtin_ctz(lim);
    for (int i = 0;i < lim;++ i) buf[rev[i] >> shift] = a[i];

    // Butterfly passes with half-length 1, 2, 4, ..., lim / 2.
    for (int half = 1;half < lim;half <<= 1){
        const int span = half << 1;
        const int *tw = w + half;
        for (int base = 0;base < lim;base += span){
            int *lo = buf + base;
            int *hi = lo + half;
            for (int k = 0;k < half;++ k){
                const int u = lo[k];
                const int v = mul (tw[k],hi[k]);
                hi[k] = dec (u,v);
                lo[k] = add (u,v);
            }
        }
    }

    // add/dec can hand back exactly `mod` when an input equals `mod`, so
    // reduce once more while copying the result out.
    for (int i = 0;i < lim;++ i) a[i] = buf[i] % mod;

    if (type != -1) return ;

    // Inverse transform: reverse positions 1 .. lim-1 and scale by 1 / lim.
    reverse (a.begin() + 1,a.begin() + lim);
    const int inv_lim = inv (lim);
    for (int i = 0;i < lim;++ i) a[i] = mul (a[i],inv_lim);
}

// Coefficient-wise sum of two polynomials modulo mod; the result has the
// length of the longer operand.
poly operator + (poly a,poly b){
    if (a.size() < b.size()) a.resize (b.size());
    for (size_t i = 0;i < b.size();++ i) a[i] = add (a[i],b[i]);
    return a;
}

// Product of two polynomials modulo mod, computed with the NTT.
// The result has size(a) + size(b) - 1 coefficients. When that count is not
// positive (an empty operand), an empty polynomial is returned: this avoids
// resizing to -1, which would convert to a huge size_t and throw.
// Operands are taken by value because both are padded and transformed in place.
poly operator * (poly a,poly b){
    const int d = (int)a.size() + (int)b.size() - 1;
    if (d <= 0) return poly ();

    // Smallest power of two that holds the full product without wrap-around.
    int lim = 1;
    while (lim < d) lim <<= 1;

    a.resize (lim),b.resize (lim);
    ntt (a,lim,1),ntt (b,lim,1);
    for (int i = 0;i < lim;++ i) a[i] = mul (a[i],b[i]);
    ntt (a,lim,-1);
    a.resize (d);
    return a;
}

// Reads one decimal integer from stdin into t.
// Leading non-digit characters are skipped; every '-' among them flips the
// sign. If the input ends before any digit is found, t is left at 0 instead
// of looping forever. getchar()'s result is kept in an int so that EOF stays
// distinguishable from ordinary characters.
template <typename T> inline void read (T &t){
    t = 0;
    int sign = 1;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')){
        if (c == '-') sign = -sign;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        t = t * 10 + (c - '0');
        c = getchar();
    }
    t *= sign;
}
// Reads several integers in order: read(a, b, c) is read(a); read(b); read(c).
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Prints the integer x to stdout in decimal, with a leading '-' for negative
// values and no trailing separator.
template <typename T> inline void write (T x){
    if (x < 0){
        putchar ('-');
        x = -x;
    }
    // Collect digits least-significant first, then emit them in reverse.
    char digits[40];
    int len = 0;
    do {
        digits[len ++] = (char)(x % 10 + '0');
        x /= 10;
    } while (x > 0);
    while (len > 0) putchar (digits[-- len]);
}

// n    - number of input coefficients, read in main
// fac  - fac[i]  = i!      modulo mod, filled in main for i in 0..n
// ifac - ifac[i] = 1 / i!  modulo mod, filled in main for i in 0..n
int n,fac[MAXN],ifac[MAXN];
// fuck - first the n input coefficients; main then replaces it with the first
//        n + 1 coefficients of (input * III), which the leaves of divide() read
// III  - III[i] = ifac[i] for i in 0..n
// D, DR - per-node polynomials of the recursion in divide(); node i has the
//        children 2i and 2i + 1
// NOTE(review): each of D and DR reserves MAXN << 2 (about 8 million) vector
// objects, while divide() only visits nodes of a tree over [0, n] — consider
// sizing these from the real bound on n; confirm against the memory limit.
poly fuck,III,D[MAXN << 2],DR[MAXN << 2];

void divide (int i,int l,int r){
    if (l == r){
        DR[i].resize (2),DR[i][0] = mod - l,DR[i][1] = 1;
        D[i].resize (1),D[i][0] = (n - l) & 1 ? mod - mul (fuck[l],ifac[n - l]) : mul (fuck[l],ifac[n - l]);
        return ;
    } 
    int mid = (l + r) >> 1;divide (i << 1,l,mid),divide (i << 1 | 1,mid + 1,r);
    D[i] = DR[i << 1 | 1] * D[i << 1] + DR[i << 1] * D[i << 1 | 1];
    DR[i] = DR[i << 1] * DR[i << 1 | 1];
}

// Reads n and the n falling-factorial coefficients, converts them to
// ordinary-polynomial coefficients and prints those n values separated by
// spaces on one line.
signed main(){
    init_ntt();
    read (n);
    fuck.resize (n);
    III.resize (n + 1);

    // Factorials and inverse factorials for 0..n.
    fac[0] = 1;
    for (int i = 1;i <= n;++ i) fac[i] = mul (fac[i - 1],i);
    ifac[n] = inv (fac[n]);
    for (int i = n;i >= 1;-- i) ifac[i - 1] = mul (ifac[i],i);

    for (int i = 0;i < n;++ i) read (fuck[i]);
    for (int i = 0;i <= n;++ i) III[i] = ifac[i];

    // Multiply by the inverse-factorial series and keep terms 0..n, which
    // become the leaf inputs of the interpolation.
    fuck = fuck * III;
    fuck.resize (n + 1);
    divide (1,0,n);

    for (int i = 0;i < n;++ i){
        write (D[1][i]);
        putchar (' ');
    }
    putchar ('\n');
    return 0;
}