题解 P5408 【第一类斯特林数·行】

wurzang

2020-07-09 16:25:40

题解

进博客看体验更佳:https://www.luogu.com.cn/blog/id-exzang/solution-p5408

x^{\bar{n}}=\sum_{i=0}^n {\begin{bmatrix}n \\i\end{bmatrix}}x^i

其中 x^{\bar{n}} 表示上升幂,即 \prod_{i=x}^{x+n-1}i

可以直接用归纳法证明。

x^{\bar{n+1}}=(\sum_{i=0}^n \begin{bmatrix} n \\ i \end{bmatrix}x^i)(x+n) =\sum_{i=0}^{n+1} (\begin{bmatrix} n \\ i-1 \end{bmatrix} +n \times \begin{bmatrix} n\\i \end{bmatrix})x^i =\sum_{i=0}^{n+1}\begin{bmatrix}n+1\\i\end{bmatrix}x^i

至于怎么算 x^{\bar{n}}...可以采取倍增法。

显然 x^{\bar{2n}}=x^{\bar{n}}(x+n)^{\bar{n}}

假设我们已经知道了 x^{\bar{n}} ,那么瓶颈就在于如何算出 (x+n)^{\bar{n}}

f(x)=x^{\bar{n}},那么 (x+n)^{\bar{n}} 就是 f(x+n)

那么问题就变为怎么从多项式 f(x) 推到多项式 f(x+n)

假设 f(x) 的第 i 项系数为 a_i ,那么就可以推导如下:

\begin{aligned} f(x+n)&=\sum_{i=0}^{n}a_i(x+n)^i \\ &= \sum_{i=0}^{n} a_i \sum_{j=0}^i x^j n^{i-j} \binom i j \\ &= \sum_{j=0}^n x^j \sum_{i=j}^n a_i n^{i-j} \binom i j \end{aligned}

然后拆一下二项式,可以得到:

\begin{aligned} f(x+n)&=\sum_{j=0}^n \frac{x^j}{j!} \sum_{i=j}^n (a_ii!) \frac{n^{i-j}}{(i-j)!} \end{aligned}

我们设 A(i)=a_ii!,B(i)=\frac{n^i}{i!},就有:

\begin{aligned} f(x+n)&=\sum_{j=0}^n\frac{x^j}{j!} \sum_{i=j}^n A(i) \times B(i-j)\\ &= \sum_{j=0}^n \frac{x^j}{j!} \sum_{i=0}^{n-j} A(i+j) \times B(i) \\ &= \sum_{j=0}^n \frac{x^j}{j!} \sum_{i=0}^{n-j} A'(n-j-i) \times B(i) \end{aligned}

其中 A' 序列 表示 A 序列反转后的序列。

后面这个式子是一个卷积形式,可以直接 NTT 优化。

然后就可以 \mathcal{O}(n \log n) 算出 f(x+n)

至于如何计算 f(x)f(x+n) 可以也来一发 NTT

时间复杂度 T(n)=T(\frac{n}{2})+\mathcal{O}(n\log n)=\mathcal{O}(n \log n)

以下是代码:

#include<bits/stdc++.h>
using namespace std;
#define int long long 
typedef long long ll;
const ll mod=167772161;
ll G=3,invG;
const int N=1200000;
ll ksm(ll b,int n){
    ll res=1;
    while(n){
        if(n&1) res=res*b%mod;
        b=b*b%mod; n>>=1;
    }
    return res;
}
int read(){
    int x=0;char ch=getchar();
    while(!isdigit(ch))ch=getchar();
    while(isdigit(ch)) x=(x*10+(ch-'0'))%mod,ch=getchar();
    return x;
}
int tr[N];
void NTT(ll *f,int n,int fl){
    for(int i=0;i<n;++i)
        if(i<tr[i]) swap(f[i],f[tr[i]]);
    for(int p=2;p<=n;p<<=1){
        int len=(p>>1);
        ll w=ksm((fl==0)?G:invG,(mod-1)/p);
        for(int st=0;st<n;st+=p){
            ll buf=1,tmp;
            for(int i=st;i<st+len;++i)
                tmp=buf*f[i+len]%mod,
                f[i+len]=(f[i]-tmp+mod)%mod,
                f[i]=(f[i]+tmp)%mod,
                buf=buf*w%mod;
        }
    }
    if(fl==1){
        ll invN=ksm(n,mod-2);
        for(int i=0;i<n;++i)
            f[i]=(f[i]*invN)%mod;
    }
}
void Mul(ll *f,ll *g,int n,int m){
    m+=n;n=1;
    while(n<m) n<<=1;
    for(int i=0;i<n;++i)
        tr[i]=(tr[i>>1]>>1)|((i&1)?(n>>1):0);
    NTT(f,n,0);
    NTT(g,n,0);
    for(int i=0;i<n;++i) f[i]=f[i]*g[i]%mod;
    NTT(f,n,1);
}
ll inv[N],fac[N];
ll w[N],a[N],b[N],g[N];
void Solve(ll *f,int m){
    if(m==1) return f[1]=1,void(0);
    if(m&1){
        Solve(f,m-1);
        for(int i=m;i>=1;--i)
            f[i]=(f[i-1]+f[i]*(m-1)%mod)%mod;
        f[0]=f[0]*(m-1)%mod;
    }
    else{
        int n=m/2;ll res=1;
        Solve(f,n);
        for(int i=0;i<=n;++i)
            a[i]=f[i]*fac[i]%mod,b[i]=res*inv[i]%mod,res=res*n%mod;
        reverse(a,a+n+1);
        Mul(a,b,n+1,n+1);
        for(int i=0;i<=n;++i)
            g[i]=inv[i]*a[n-i]%mod;
        Mul(f,g,n+1,n+1);
        int limit=1;
        while(limit<(n+1)<<1) limit<<=1;
        for(int i=n+1;i<limit;++i) a[i]=b[i]=g[i]=0;
        for(int i=m+1;i<limit;++i) f[i]=0;
    }
}
ll f[N];
void init(int n){
    fac[0]=1;
    for(int i=1;i<=n;++i)
        fac[i]=1ll*fac[i-1]*i%mod;
    inv[n]=ksm(fac[n],mod-2);
    for(int i=n-1;i>=0;--i)
        inv[i]=1ll*inv[i+1]*(i+1)%mod;
}
signed main(){
    invG=ksm(G,mod-2);
    int n,k=0;
    cin>>n;
    init(n+n);
    Solve(f,n);
    for(int i=0;i<=n;++i)
        printf("%lld ",f[i]);
    return 0;
}