CF1227F2 Wrong Answer on test 233 (Hard Version)

· · 题解

\text{Description}

n 道题,你的程序在上交答案时把答案交串了,第 i 个答案变成了第 i\bmod n+1 个.
给出 n 道题的正确答案以及选项数 k ,求有多少中初始上交的答案,能使交串后正确的题目数变多.

\text{Solution}

easy version

n\le2000,k\le10^9

设计 dp_{i,j} 表示填到第 i 题,多填对了 j 道题的方案数.
暴力转移即可.

hard version

n\le2\times10^5,k\le10^9

原来 dp 的做法行不通了,我们需要另辟蹊径.
f_i 表示多对了 i 道题的方案数,由于对称性,有 f_i=f_{-i}.
所以我们的答案其实就是:

\dfrac{k^n-f_0}{2}

所以我们只需要求出 f_0 就行了.
连续两个答案相同无法产生差异,我们扫一遍求出有 m 个位置可以产生差异.
枚举答对(答错)的题目数,则有:

f_0=\sum_{i=0}^{\lfloor\frac{m}{2}\rfloor}k^{n-m}\times C_m^i\times C_{m-i}^i\times (k-2)^{m-2i}

直接求解即可.

\text{Code}

dp:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
const int N=2050;
const int mod=998244353;
inline ll read(){
  ll x(0),f(1);char c=getchar();
  while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
  while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
  return x*f;
}
int n,m,k;
ll dp[2][N<<1];
int now,nxt,o=2000,a[N];
signed main(){
#ifndef ONLINE_JUDGE
  freopen("a.in","r",stdin);
  freopen("a.out","w",stdout);
#endif
  n=read();k=read();
  for(int i=1;i<=n;i++) a[i]=read();
  dp[nxt=1][0+o]=1;
  for(int i=1;i<=n;i++){
    swap(now,nxt);
    memset(dp[nxt],0,sizeof(dp[nxt]));
    for(int j=-i;j<=i;j++){
      if(a[i]==a[i%n+1]) (dp[nxt][j+o]+=dp[now][j+o]*k)%=mod;
      else{
    (dp[nxt][j+1+o]+=dp[now][j+o])%=mod;
    (dp[nxt][j-1+o]+=dp[now][j+o])%=mod;
    (dp[nxt][j+o]+=(k-2)*dp[now][j+o])%=mod;
      }
    }
  }
  ll res(0);
  for(int i=1;i<=n;i++) (res+=dp[nxt][i+o])%=mod;
  printf("%lld\n",res);
  return 0;
}
/*
*/

组合数:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
const int N=2e5+100;
const int mod=998244353;
inline ll read(){
  ll x(0),f(1);char c=getchar();
  while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
  while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
  return x*f;
}
int n,m,k;
int a[N];
ll ans;
ll ksm(ll x,ll k){
  ll res=1;
  while(k){
    if(k&1) res=res*x%mod;
    x=x*x%mod;
    k>>=1;
  }
  return res;
}
ll jc[N],ni[N],mi[N];
inline ll C(int n,int m){
  return jc[n]*ni[m]%mod*ni[n-m]%mod;
}
signed main(){
#ifndef ONLINE_JUDGE
  freopen("a.in","r",stdin);
  freopen("a.out","w",stdout);
#endif
  n=read();k=read();
  if(k==1){
    printf("0");return 0;
  }
  for(int i=1;i<=n;i++) a[i]=read();
  for(int i=1;i<=n;i++) m+=(a[i]!=a[i%n+1]);
  jc[0]=1;
  for(int i=1;i<=m;i++) jc[i]=jc[i-1]*i%mod;
  ni[m]=ksm(jc[m],mod-2);
  for(int i=m-1;i>=0;i--) ni[i]=ni[i+1]*(i+1)%mod;
  mi[0]=1;
  for(int i=1;i<=n;i++) mi[i]=mi[i-1]*(k-2)%mod;
  ans=ksm(k,n);ll w=ksm(k,n-m);
  for(int i=0;i<=m/2;i++) ans=(ans+mod-w*C(m,i)%mod*C(m-i,i)%mod*mi[m-2*i]%mod)%mod;
  ans*=ksm(2,mod-2);ans%=mod;
  printf("%lld\n",ans);
  return 0;
}
/*
*/