题解 CF1349D 【Slime and Biscuits】

· · 题解

无限sto dls

题意简述

i 个人初始有 a_i 个饼干,现在进行若干次操作,每一次操作会从所有饼干中随机一块,然后把这块饼干随机分给除了原主人以外的某个人。当所有饼干集中在一个人手里时,结束操作。问操作的期望次数,对 998244353 取模。

2\le n\le 100000\ ,\ 0\le a_i\le 300000

算法分析

这里介绍一下dls发明的势能法,想法更自然,且可以作为解决这一类问题的通法。

我们考虑构造一个函数 f ,满足进行一步操作之后 \sum\limits_{i=1}^nf(a_i)-1 ,则操作的期望次数就是 初状态势函数和-末状态势函数和 ,本题即为:

\sum\limits_{i=1}^nf(a_i)-(n-1)f(0)-f(m)

考虑每一次操作,我们枚举选到的饼干是哪一个人的,再枚举这个饼干后来分给了谁,即:

\sum\limits_{i=1}^nf(a_i)=1+\sum\limits_{i=1}^n\frac{a_i}{m}\left(f(a_i-1)+\sum\limits_{j\neq i}\frac{1}{n-1}f(a_j+1)+\frac{n-2}{n-1}f(a_j)\right)

前面的 1 是我们对势函数的要求,后面第 i 个人的饼干被选中的概率是 \frac{a_i}{m},在这次操作后,第 i 个人的饼干个数会变成 a_i-1 ,再枚举这个饼干给了谁,剩下的 n-1 个人每一个人都有 \frac{1}{n-1} 的概率拿到这一块饼干,变成 a_j+1 ,其余情况不变。

我们稍微化一下这个式子:

\sum\limits_{i=1}^nf(a_i)=\sum\limits_{i=1}^n\frac{a_i}{m}(f(a_i-1)+1)+\frac{m-a_i}{m(n-1)}f(a_i+1)+\frac{(m-a_i)(n-2)}{m(n-1)}f(a_i)

这里我们把前面的 +1 放到了 \sum\limits_{i=1}^n\frac{a_i}{m}(f(a_i-1)+\color{red}1) ,不难证明这个是等价的。

这个方程看上去没啥用,但是因为我们只要能构造出一个这样的势函数就可以,因此不妨给它加强一下,让每一个 i 前后项都对应相等,即:

f(a)=\frac{a}{m}(f(a-1)+1)+\frac{m-a}{m(n-1)}f(a+1)+\frac{(m-a)(n-2)}{m(n-1)}f(a)

这样我们就可以得到 m+1 个这样的方程,下面可以用 band-matrix 或者用 f(i)=k_if(m)+b_i 的方法解出这个方程组,时间复杂度是 O(n) 或者 O(n\log \mathrm{MOD}) 的,取决于消元过程中用快速幂求逆元还是预处理出逆元。

代码实现

我写的是第二种方法,且没有预处理逆元,时间复杂度 O(n\log n)

#include<bits/stdc++.h>
using namespace std;
#define maxn 1000005
#define maxm 2000005
#define inf 0x3f3f3f3f
#define LL long long
#define ull unsigned long long
#define mod 998244353
#define eps 1e-6
#define local
void file(string s){freopen((s+".in").c_str(),"r",stdin);freopen((s+".out").c_str(),"w",stdout);}
template <typename Tp> void read(Tp &x){
    int fh=1;char c=getchar();x=0;
    while(c>'9'||c<'0'){if(c=='-'){fh=-1;}c=getchar();}
    while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+(c&15);c=getchar();}x*=fh;
}
int ksm(int b,int p,int Mod=mod){int ret=1;while(p){if(p&1)ret=1ll*ret*b%Mod;b=1ll*b*b%Mod;p>>=1;}return ret;}
int n,m;
int a[maxn],f[maxn];
struct node{//维护kf(m)+b的一个类型
    int k,b;
    node operator +(node y)const{
        return (node){(k+y.k)%mod,(b+y.b)%mod};
    }
    node operator -(node y)const{
        return (node){(k+mod-y.k)%mod,(b+mod-y.b)%mod};
    }
    node operator *(int y)const{
        return (node){1ll*k*y%mod,1ll*b*y%mod};
    }
    node operator +(int y)const{
        return (node){k,(b+y)%mod};
    }
    node operator -(int y)const{
        return (node){k,(b+mod-y)%mod};
    }
};
node ff[maxn];
signed main(){
    read(n);
    for(int i=1;i<=n;++i)read(a[i]),m+=a[i];
    ff[m]=(node){1,0};
    for(int i=m-1;i>=1;--i){//代入消元,求出每一个f(i)的 kf(m)+b 表示
        int di=1ll*(m-i)*ksm(1ll*m*(n-1)%mod,mod-2)%mod;
        int di1=1ll*i*ksm(m,mod-2)%mod;
        int di2=1ll*(m-i)*(n-2)%mod*ksm(1ll*m*(n-1)%mod,mod-2)%mod;
        ff[i-1]=ff[i]-ff[i]*di2-ff[i+1]*di;
        ff[i-1]=ff[i-1]-di1;
        ff[i-1]=ff[i-1]*ksm(di1,mod-2);
    }//边界(手算可以发现)f(0)=f(1),直接解方程
    int f0=1ll*(ff[0].b-ff[1].b)*ksm(ff[1].k-ff[0].k,mod-2)%mod;
    for(int i=0;i<=m;++i){//还原f(i)
        f[i]=(1ll*f0*ff[i].k%mod+ff[i].b)%mod;
        f[i]=(f[i]+mod)%mod;
    }
    int ans=0;//求答案
    for(int i=1;i<=n;++i)ans=(ans+f[a[i]])%mod;
    ans=(ans-1ll*(n-1)*f[0]%mod)%mod;
    ans=(ans-f[m])%mod;
    ans=(ans+mod)%mod;
    printf("%d\n",ans);
    return 0;
}