题解 CF1278F 【Cards】

NaCly_Fish

2020-01-29 03:13:18

题解

在我的博客内看体验更佳(

updated:改为 \Theta(k) 的解法。
ps:这个复杂度的做法 iostream 比我先做出来。

p=1/mq=1-p,那么枚举第一张为王牌的次数,有这么一个暴力计算的式子

\sum_{i=0}^n\binom{n}{i}p^iq^{n-i}i^k

后面那个 i^k 可以展开为第二类 Stirling 数

\sum_{i=0}^n\binom{n}{i}p^iq^{n-i}\sum_{j=1}^k \begin{Bmatrix} k \\ j \end{Bmatrix}i^{\underline j} =\sum_{j=1}^k\begin{Bmatrix} k \\ j \end{Bmatrix}\sum_{i=0}^n\binom{n}{i}p^iq^{n-i}i^{\underline j}

考虑化简后面那个式子

=\sum_{i=j}^n \binom ni p^i(1-p)^{n-i}\binom{i}{j}j! =j!\binom{n}{j}\sum_{i=j}^n\binom{n-j}{i-j}p^i(1-p)^{n-i} =n^{\underline j} \sum_{i=0}^{n-j}\binom{n-j}{i}p^{i+j}(1-p)^{n-j-i} =n^{\underline j}p^j

然后直接得到原式为

\sum_{j=0}^k\begin{Bmatrix} k \\ j \end{Bmatrix}n^{\underline j}p^j

这样就可以 \Theta(k \log k) 计算了,然而还可以更优。
暴力拆开原式的 Stirling 数得到

\sum_{j=0}^k\frac{1}{j!}\sum_{i=0}^j(-1)^{j-i}\binom{j}{i} i^k j! \binom nj p^j =\sum_{i=0}^ki^k\sum_{j=i}^k(-1)^{j-i} \binom nj \binom ji p^j

根据组合数的基本意义可以化为

\sum_{i=0}^ki^k\binom ni\sum_{j=i}^k(-1)^{j-i}\binom{n-i}{j-i}p^j

后面一个和式改为枚举从 0k-i,所有含 j 的式子都 +i 得到

\sum_{i=0}^ki^k\binom ni p^i\sum_{j=0}^{k-i}(-1)^j\binom{n-i}{j}p^j

设后面那个和式为 f(i),显然有 f(k)=1
做个差分,再做一些麻烦的推式子,就能得到如下关系:
(我的推法又臭又长,而且也没什么技术含量,就不放上来了)

f(i)=(-p)^{k-i}\binom{n-i-1}{k-i}+(1-p)f(i+1)

用线性筛求出 i^k \ (i\in [1,k]),时间复杂度就可以做到 \Theta(k)
要注意的是 n \le k 的时候会有点问题,直接用最开始的式子,暴力计算即可。

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 10000003
#define ll long long
#define p 998244353
#define reg register
using namespace std;

inline int power(int a,int t){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%p;
        a = (ll)a*a%p;
        t >>= 1;
    }
    return res;
}

int n,m,k,cnt;
int f[N],inv[N],pw[N],pr[N>>1],c[N];

int solve1(){
    int mul = (ll)power(p+1-m,n-1)*m%p,invq = power(p+1-m,p-2),res = 0;
    for(reg int i=1;i<=n;++i){
        res = (res+(ll)mul*c[i]%p*pw[i])%p;
        mul = (ll)mul*invq%p*m%p;
    }
    return res;
}

int solve2(){
    int c2,mul,res = 0;
    mul = c2 = f[k] = 1;
    for(reg int i=k-1;i;--i){
        c2 = (ll)c2*(n-i-1)%p*inv[k-i]%p;
        mul = (ll)mul*(p-m)%p;
        f[i] = ((ll)c2*mul+(ll)(p+1-m)*f[i+1])%p;
    }
    mul = m;
    for(reg int i=1;i<=k;++i){
        res = (res+(ll)pw[i]*c[i]%p*mul%p*f[i])%p;
        mul = (ll)mul*m%p;
    }
    return res;
}

int main(){
    scanf("%d%d%d",&n,&m,&k);
    m = power(m,p-2);
    c[0] = inv[1] = pw[1] = 1;
    c[1] = n;
    for(reg int i=2;i<=k;++i){
        inv[i] = (ll)(p-p/i)*inv[p%i]%p;
        c[i] = (ll)c[i-1]*inv[i]%p*(n-i+1)%p;
        if(!pw[i]){
            pr[++cnt] = i;
            pw[i] = power(i,k);
        }
        for(reg int j=1;j<=cnt&&(ll)i*pr[j]<=k;++j){
            pw[i*pr[j]] = (ll)pw[i]*pw[pr[j]]%p;
            if(i%pr[j]==0) break;
        }
    }
    if(n<=k+1) printf("%d",solve1());
    else printf("%d",solve2());
    return 0;   
}