Codeforces Round #641 Div1.D Slime and Biscuits 中文题解

Caro23333

2020-04-30 00:30:44

题解

Slime and Biscuits 中文题解

E_x 为在所有游戏结束时所有小饼干在 x 手里的情况,概率乘时间的和(注意此时所有情况的概率加起来并不是 1 ,所有 E_i 中概率的和才是 1 )。则答案即为 \sum\limits_{i=1}^n E_i

E'_x 为如果游戏只在所有小饼干都在 x 手里才结束时游戏的期望进行时间。

P_x 为最终游戏结束时所有小饼干在 x 手里的概率,即 E_x 这一部分对应的总概率。不难发现\sum\limits_{i=1}^n P_i = 1

再设常数C表示所有小饼干从 i 手里转移到 j 手里的期望时间(此时游戏结束条件是全集中在 j 手里,与 E'_j 的结束条件相同,且 C 的值对所有 i,j 均相等),则不难发现我们有恒等式:

E_x = E'_x-\sum\limits_{i=1}^n[i\neq x]( P_i\cdot C+E_i)

这个等式即为考虑 E'_x 中的所有情况中游戏分别在最后饼干在谁手里时结束。

将等式移项可得 :

\sum\limits_{i=1}^n E_i=E'_x-C\cdot \sum\limits_{i=1}^n[i\neq x]P_i

将等式对 x=1,2,\cdots,n 求和,可以得到 :

n\sum\limits_{i=1}^nE_i=\sum\limits_{i=1}^nE'_i-C(n-1)\sum\limits_{i=1}^nP_i

注意到 ans=\sum\limits_{i=1}^nE_i\sum\limits_{i=1}^nP_i=1 ,则得出:

n\cdot ans=\sum\limits_{i=1}^nE'_i-C(n-1)

在求 E'_xC 时我们可以注意到只关心所有饼干是否在目标人手中,所以设 f_m 为此时有 m 个小饼干在目标人手中时最后全到目标人手中的期望时间,则不难得到 f_m,f_{m+1},f_{m-1} 的等式关系:

f_m = \left\{ \begin{array}{lcl} 1 + \frac{(\sum a_i) - m}{\sum a_i}\left(\frac{1}{n-1}f_{m+1} + \frac{n-2}{n-1}f_m\right) + \frac{m}{\sum a_i}f_{m-1} & 0<m<\sum a_i \\ 1 + \frac{1}{n-1}f_{m+1} + \frac{n-2}{n-1}f_m & m=0 \\ 0 & m=\sum a_i \\ \end{array} \right.

不难在 O(\sum\limits_{i=1}^na_i) 时间内,通过消元的方法求出所有 f_n ,从而求出 E'_xC ,得出答案。

总复杂度为 O(n+\sum\limits_{i=1}^na_i)

P.S. 事实上,这种做法可能在消元过程中出现除以 0 的情况,所以我们安排了一组 hack 数据。另一种更为妥帖的做法是,将 g_i = f_i - f_{i+1} 设为未知数列方程求解,可以证明不会出现问题。标程中使用了这种做法。这种做法由 Elegia 提供。

Problem and Tutorial by he_____he

#include<bits/stdc++.h>

#define pb push_back
#define mp make_pair
#define fi first
#define se second

using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

template <typename T> bool chkmax(T &x,T y){return x<y?x=y,true:false;}
template <typename T> bool chkmin(T &x,T y){return x>y?x=y,true:false;}

int readint(){
    int x=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

const int cys=998244353;
int n,m;
ll a[100005],ans[300005];

ll qpow(ll x,ll p){
    ll ret=1;
    for(;p;p>>=1,x=x*x%cys) if(p&1) ret=ret*x%cys;
    return ret;
}

int main(){
    n=readint();
    for(int i=1;i<=n;i++) a[i]=readint(),m+=a[i];
    ll invm=qpow(m,cys-2),invn1=qpow(n-1,cys-2);
    for(int i=m;i>=1;i--){
        ll k1=i*invm%cys*invn1%cys,k2=(m-i)*invm%cys;
        ans[i]=(k2*ans[i+1]+1)%cys*qpow(k1,cys-2)%cys;
    }
    for(int i=1;i<=m;i++) ans[i]=(ans[i]+ans[i-1])%cys;
    ll res=0;
    for(int i=1;i<=n;i++) res=(res+ans[m-a[i]])%cys;
    res=(res+cys-ans[m]*(n-1)%cys)%cys;
    printf("%lld\n",res*qpow(n,cys-2)%cys);
    return 0;
}