题解 AT2070 【[ARC061D] 3人でカードゲーム / Card Game for Three】

· · 题解

题意 : 有三堆牌,分别有 n_1,n_2,n_3 张。牌上写着数字

先从牌堆 $1$ 中抽一张,接下来,牌上写着几就从几号牌堆抽取。 求在所有可能的 $3^{n_1+n_2+n_3}$ 种方案中,先把牌堆 $1$ 抽空的方案数。 答案对 $10^9+7$ 取模。 $n_1,n_2,n_3\leq 3\times 10^5$ ,时限 $\texttt{3s}$。 ------------ 条件计数的路子无非两条,要么寻找更简洁的充要条件,要么容斥。 题目中的过程比较精巧,考虑寻找充要条件。 把抽取出的牌排成一个序列,显然,每种放置方式都恰好对应一个序列。(构造映射) 但是,由于可能拿不完牌,所以一个序列可能对应很多种方案,具体地,一个长为 $m$ 的序列对应 $3^{n_1+n_2+n_3-m}$ 种方案。(检查反映射) 我们思考对这个序列的约束,可以发现,除了率先将堆 $1$ 拿空之外,没有任何约束。(检查充要条件) 于是,问题就变成了 : 对每个长度,求先将堆 $1$ 拿空的序列个数。 显然,操作序列中一定恰有 $n_1$ 个 $1$ ,且最后一个必须是 $1$。 枚举抽出的非 $1$ 牌个数 $k$ ,方案数为 $\dbinom{k+n_1-1}{k}\sum\limits_{i=0}^k[i\leq n_2][k-i\leq n_3]\dbinom{k}{i}

解释一下,\binom{k+n_1-1}{k} 表示 n_1-1 个自由的 1 与非 1 混合的方案数,后面的求和是瓜分非 1 牌的方案数。

但糟糕的是,后半部分是一个组合数部分和,这似乎没有什么快速的方法分别求解,考虑递推。

S(k)=\sum\limits_{i=0}^k[i\leq n_2][k-i\leq n_3]\dbinom{k}{i} =\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k}{i}

将组合数裂开 ;

=\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k-1}{i}+\dbinom{k-1}{i-1} =\sum\limits_{k-n_3\leq i\leq n_2}\dbinom{k-1}{i}+\sum\limits_{k-n_3-1\leq i\leq n_2-1}\dbinom{k-1}{i} =2S(k-1)-\dbinom{k-1}{k-n_3-1}-\dbinom{k-1}{n_2}

注意组合数可能不合法,此时值为 0

求出各个 S(k) 之后,答案即为下式 :

\sum\limits_{k=0}^{n_2+n_3}3^{n_1+n_2+n_3-k}\dbinom{n_1-1+k}{k}S(k)

不看题解玩出来还是有点小激动的……

#include<algorithm>
#include<cstdio>
#define ll long long
#define MaxN 900500
using namespace std;
const int mod=1000000007;
ll powM(ll a,int t=mod-2){
  ll ret=1;
  while(t){
    if (t&1)ret=ret*a%mod;
    a=a*a%mod;t>>=1;
  }return ret;
}
ll fac[MaxN],ifac[MaxN];
ll C(int n,int m){
  if (m<0||n<m)return 0;
  return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
void Init(int n)
{
  fac[0]=1;
  for (int i=1;i<=n;i++)
    fac[i]=fac[i-1]*i%mod;
  ifac[n]=powM(fac[n]);
  for (int i=n;i;i--)
    ifac[i-1]=ifac[i]*i%mod;
}
void preS(int n2,int n3,ll *S)
{
  S[0]=1;
  for (int k=1;k<=n2+n3;k++)
    S[k]=(2*S[k-1]-C(k-1,k-1-n3)-C(k-1,n2))%mod;

}
int n1,n2,n3,N;
ll S[MaxN],pw3[MaxN];
int main()
{
  scanf("%d%d%d",&n1,&n2,&n3);
  Init(N=n1+n2+n3);
  preS(n2,n3,S);
  ll ans=0,buf=powM(3,N-n1),sav=powM(3);
  for (int k=0;k<=n2+n3;k++){
    ans=(ans+buf*C(n1+k-1,k)%mod*S[k])%mod;
    buf=buf*sav%mod;
  }printf("%lld",(ans+mod)%mod);
  return 0;
}