题解 P5373 【【模板】多项式复合函数】

yurzhang

2019-06-02 12:37:19

题解

前言

这是这道题复杂度比较正确的一个常数奇大无比的算法,源自 \text{R.P.BRENT}\text{H.T.KUNG}1978 年发表的论文(也就是题面里说的那个全嘤文pdf),我在 \text{5月15日} 蒯到手之后肝了一个晚上,在神犇 rqy 的帮助下成功地写出了66分点名被卡做法awsl,后来研究别人代码的时候发现了 预处理原根 这种神奇操作,于是成功AC,特来还愿。

正文

这种做法基于对外层函数的泰勒展开,我们把内层函数 G 的前 m 项拆出来,记作 G_m ,把剩下的记作 G_r 。(这个 m 到底取多少我们分析复杂度的时候再确定)

然后进行泰勒展开:

\large F(G)=F(G_m+G_r)=F(G_m)+F'(G_m)G_r+\frac12F''(G_m)G_r^2+\cdots

由于我们只需要求 F(G)n+1 项的值,所以我们只需要知道这个展开式前 \lceil\frac nm\rceil 项的值即可,即:
l=\lceil\frac nm\rceil ,有

\large F(G(x))\equiv F(G_m(x))+F'(G_m(x))G_r(x)+\cdots+\frac1{l!}F^{(l)}(G_m(x))G_r^l(x)\pmod{x^{n+1}}

因此我们只要计算出 F(G_m(x)) 的各阶导和 G_r(x) 的各次幂,就可以累加得到 F(G(x)) 了。

考虑外层函数最高次项次数为 $2$ 的幂的情况,即: 令 $F(x)=f_0+\cdots+f_jx^j$ 且 $j$ 为 $2$ 的正整数次幂,有 $$\large F(G)=F_1(G)+G^{\frac j2}\cdot F_2(G)$$ 这里 $F_1$ 和 $F_2$ 都是最高次项次数为 $\frac j2$ 的多项式,这样一来我们就能递归地计算 $F(G_m(x))$ 了。 考虑这一步的时间复杂度: 我们令 $M(n)$ 为完成两个 $n$ 次多项式的乘法的时间,即: $M(n)=O(n\log n)

我们设 T(j) 为计算 G^{\frac j2}F(G) 的时间,则有

\large T(j)\leqslant2T(\frac j2)+O(M(\min(jm,n)))

我们令 r 为满足 n\cdot2^k\leqslant jm 的最大的 k ,则有

\large T(j)=O(M(n)+2M(n)+\cdots+2^rM(n))+2^{r+1}T(\frac{j}{2^{r+1}}) \large T(j)\leqslant O(\frac{jmM(n)}{n})+(\frac{2jm}{n})T(\frac{j}{2^{r+1}})

又因为 n\cdot2^{r+1}>jm ,有

\large T(\frac{j}{2^{r+1}})=O(M(\frac{jm}{2^{r+1}})+2M(\frac{jm}{2^{r+2}})+\cdots) \large=O(M(n)+2M(\lceil\frac n2\rceil)+4M(\lceil\frac n4\rceil)+\cdots) \large=O(\log nM(n))

因此,我们有 T(j)=O(\frac{jm\log n}{n}M(n)) ,所以求解 F(G_m(x)) 的时间复杂度为 O(mn\log^2n)

我们再来考虑对 F(G_m(x)) 求导:
H(x)=F(G(x))
根据复合函数求导法则,我们有 H'(x)=F'(G(x))\cdot G'(x) ,因此 F'(G_m(x))=H'(x)\cdot(G_m'(x))^{-1} ,由此我们可以计算出泰勒展开式中某一项之后与之前的结果累加得到 H(x) ,同时计算出下一项中 F(G_m(x)) 的对应阶导。每计算一次导数的时间复杂度是 O(n\log n) 的,因此计算出所有项的时间复杂度是 O(l\cdot n\log n)=O(\frac{n^2\log n}{m})

我们已经得到了 F(G_m(x)) 的各阶导,而 G_r(x) 的各次幂只需要挨个乘起来即可,时间复杂度也是 O(\frac{n^2\log n}{m}) 的。

之后各项的计算和累加同样是 O(\frac{n^2\log n}{m}) 的,因此这个算法的总复杂度为 O(mn\log^2n+\frac{n^2\log n}{m})

我们考虑 m 的取值:根据均值不等式,当 mn\log^2n\sim\frac{n^2\log n}{m} 时时间复杂度最优,因此解得当 m\sim\sqrt{\frac{n}{\log n}} 时有最优时间复杂度 O((n\log n)^{1.5})

最后

还是附上参考代码比较好,这个东西由于常数原因必须 预处理原根 才能卡过去,而且好像也没有什么实际用处。。就当是练习码力了 2333

#pragma GCC optimize("Ofast,inline")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,sse4.1,sse4.2,popcnt,abm,mmx,avx,avx2,tune=native")
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

#define MOD 998244353
#define G 332748118
#define N 262210
#define re register
#define gc pa==pb&&(pb=(pa=buf)+fread(buf,1,100000,stdin),pa==pb)?EOF:*pa++
typedef long long ll;
static char buf[100000],*pa(buf),*pb(buf);
static char pbuf[3000000],*pp(pbuf),st[15];
inline int read() {
    re int x(0);re char c(gc);
    while(c<'0'||c>'9')c=gc;
    while(c>='0'&&c<='9')
        x=x*10+c-48,c=gc;
    return x;
}
inline void write(re int v) {
    if(v==0)
        *pp++=48;
    else {
        re int tp(0);
        while(v)
            st[++tp]=v%10+48,v/=10;
        while(tp)
            *pp++=st[tp--];
    }
    *pp++=32;
}

inline int pow(re int a,re int b) {
    re int ans(1);
    while(b)
        ans=b&1?(ll)ans*a%MOD:ans,a=(ll)a*a%MOD,b>>=1;
    return ans;
}

int inv[N],ifac[N];
inline void pre(re int n) {
    inv[1]=ifac[0]=1;
    for(re int i(2);i<=n;++i)
        inv[i]=(ll)(MOD-MOD/i)*inv[MOD%i]%MOD;
    for(re int i(1);i<=n;++i)
        ifac[i]=(ll)ifac[i-1]*inv[i]%MOD;
}

inline int getLen(re int t) {
    return 1<<(32-__builtin_clz(t));
}

int lmt(1),r[N],w[N];
inline void init(re int n) {
    re int l(0);
    while(lmt<=n)
        lmt<<=1,++l;
    for(re int i(1);i<lmt;++i)
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    re int wn(pow(3,(MOD-1)/lmt));
    w[lmt>>1]=1;
    for(re int i((lmt>>1)+1);i<lmt;++i)
        w[i]=(ll)w[i-1]*wn%MOD;
    for(re int i((lmt>>1)-1);i;--i)
        w[i]=w[i<<1];
}

inline void DFT(int*a,re int l) {
    static unsigned long long tmp[N];
    re int u(__builtin_ctz(lmt)-__builtin_ctz(l)),t;
    for(re int i(0);i<l;++i)
        tmp[i]=(a[r[i]>>u])%MOD;
    for(re int i(1);i<l;i<<=1)
        for(re int j(0),step(i<<1);j<l;j+=step)
            for(re int k(0);k<i;++k)
                t=(ll)w[i+k]*tmp[i+j+k]%MOD,
                tmp[i+j+k]=tmp[j+k]+MOD-t,
                tmp[j+k]+=t;
    for(re int i(0);i<l;++i)
        a[i]=tmp[i]%MOD;
}

inline void IDFT(int*a,re int l) {
    std::reverse(a+1,a+l);DFT(a,l);
    re int bk(MOD-(MOD-1)/l);
    for(re int i(0);i<l;++i)
        a[i]=(ll)a[i]*bk%MOD;
}

int n,m;
int a[N],b[N],c[N];

void getInv(int*a,int*b,int deg) {
    if(deg==1)
        b[0]=pow(a[0],MOD-2);
    else {
        static int tmp[N];
        getInv(a,b,(deg+1)>>1);
        re int l(getLen(deg<<1));
        for(re int i(0);i<l;++i)
            tmp[i]=i<deg?a[i]:0;
        DFT(tmp,l),DFT(b,l);
        for(re int i(0);i<l;++i)
            b[i]=(2ll-(ll)tmp[i]*b[i]%MOD+MOD)%MOD*b[i]%MOD;
        IDFT(b,l);
        for(re int i(deg);i<l;++i)
            b[i]=0;
    }
}

inline void getDer(int*a,int*b,int deg) {
    for(re int i(0);i+1<deg;++i)
        b[i]=(ll)a[i+1]*(i+1)%MOD;
    b[deg-1]=0;
}

void getComp(int*a,int*b,int k,int m,int&n,int*c,int*d) {
    if(k==1) {
        for(re int i(0);i<m;++i)
            c[i]=0,d[i]=b[i];
        n=m,c[0]=a[0];
    } else {
        static int t1[N],t2[N];
        int nl(n),nr(n),*cl,*cr,*dl,*dr;
        getComp(a,b,k>>1,m,nl,cl=c,dl=d);
        getComp(a+(k>>1),b,(k+1)>>1,m,nr,cr=c+nl,dr=d+nl);
        n=std::min(n,nl+nr-1);
        re int _l(getLen(nl+nr));
        for(re int i(0);i<_l;++i)
            t1[i]=i<nl?dl[i]:0;
        for(re int i(0);i<_l;++i)
            t2[i]=i<nr?cr[i]:0;
        DFT(t1,_l),DFT(t2,_l);
        for(re int i(0);i<_l;++i)
            t2[i]=(ll)t1[i]*t2[i]%MOD;
        IDFT(t2,_l);
        for(re int i(0);i<n;++i)
            c[i]=((i<nl?cl[i]:0)+t2[i])%MOD;
        for(re int i(0);i<_l;++i)
            t2[i]=i<nr?dr[i]:0;
        DFT(t2,_l);
        for(re int i(0);i<_l;++i)
            t2[i]=(ll)t1[i]*t2[i]%MOD;
        IDFT(t2,_l);
        for(re int i(0);i<n;++i)
            d[i]=t2[i];
    }
}

inline void getComp(int*a,int*b,int*c,int deg) {
    static int ts[N],ps[N],c0[N],_t1[N],idM[N];
    int M(std::max((int)ceil(sqrt(deg/log2(deg))*2.5),2)),_n(deg+deg/M);
    getComp(a,b,deg,M,_n,c0,_t1);
    re int _l(getLen(_n+deg));
    for(re int i(_n);i<_l;++i)
        c0[i]=0;
    for(re int i(0);i<_l;++i)
        ps[i]=i==0;
    for(re int i(0);i<_l;++i)
        ts[i]=M<=i&&i<deg?b[i]:0;
    getDer(b,_t1,M);
    for(re int i(M-1);i<deg;++i)
        _t1[i]=0; /// Important!!!
    getInv(_t1,idM,deg);
    for(int i=deg;i<_l;++i)
        idM[i]=0;
    DFT(ts,_l),DFT(idM,_l);
    for(re int t(0);t*M<deg;++t) {
        for(re int i(0);i<_l;++i)
            _t1[i]=i<deg?c0[i]:0;
        DFT(ps,_l),DFT(_t1,_l);
        for(re int i(0);i<_l;++i)
            _t1[i]=(ll)_t1[i]*ps[i]%MOD,
            ps[i]=(ll)ps[i]*ts[i]%MOD;
        IDFT(ps,_l),IDFT(_t1,_l);
        for(re int i(deg);i<_l;++i)
            ps[i]=0;
        for(re int i(0);i<deg;++i)
            c[i]=((ll)_t1[i]*ifac[t]+c[i])%MOD;
        getDer(c0,c0,_n);
        for(re int i(_n-1);i<_l;++i)
            c0[i]=0;
        DFT(c0,_l);
        for(re int i(0);i<_l;++i)
            c0[i]=(ll)c0[i]*idM[i]%MOD;
        IDFT(c0,_l);
        for(re int i(_n-1);i<_l;++i)
            c0[i]=0;
    }
}

int main() {
    n=read(),m=read();
    for(re int i(0);i<=n;++i)
        a[i]=read();
    for(re int i(0);i<=m;++i)
        b[i]=read();

    m=(n>m?n:m)+1;
    pre(m);init(m*5);
    getComp(a,b,c,m);

    for(re int i(0);i<=n;++i)
        write(c[i]);
    fwrite(pbuf,1,pp-pbuf,stdout);
    return 0;
}