题解 P4007 【小 Y 和恐怖的奴隶主】
有一道叫做抵制克苏恩的题,本题是那道题的加强版。
那道题的做法是期望DP。
设
则攻击1血随从、2血随从、3血随从、boss的概率分别为
转移如下:
若
若
若
然后
记录伤害期望的话,再弄一个
对于本题,上面对应的就是
直接转移显然TLE。
这个东西在
这时我们需要一个额外的状态来记录boss的扣血期望。于是最多有166个状态。
那么时间复杂度
还是会TLE。
我们发现,一个行向量乘一个方阵的复杂度是
这样的时间复杂度是
然后喜闻乐见这题卡常数,在UOJ上很可能过不了hack数据(洛谷上应该能直接过)。
有一种优化方法是减少矩阵乘法的取模次数。我每次计算的时候,直接用一个__int128
来记录中间结果,这样每次加完一遍后再取模即可。
然后由于评测机的不稳定性,还是有可能TLE。
随手加几行编译指令即可QAQ
Code:
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<cstdio>
#include<cstring>
#define LL long long
#define md 998244353
#define STC static_cast<LL>
#define reg register
struct matrix{
int a[167][167];
int r,c;
inline matrix(){memset(a,0,sizeof a);}
matrix operator*(const matrix&b)const{
matrix c;
c.r=r,c.c=b.c;
for(int i=0;i<r;++i)
for(reg int j=0;j<b.c;++j){
reg __int128 tmp=0;
for(reg int k=0;k<b.r;++k)
tmp+=STC(a[i][k])*b.a[k][j];
c.a[i][j]=tmp%md;
}
return c;
}
}p[62],a;
int T,m,k,num[9][9][9],inv[10];
void init3(){
int cnt=0;
for(int i=0;i<=k;++i)
for(int j=k-i;~j;--j)
for(int s=k-i-j;~s;--s)
num[i][j][s]=cnt++;
for(int i=0;i<=k;++i)
for(int j=k-i;~j;--j)
for(int s=k-i-j;~s;--s){
int&id=num[i][j][s];
const int ni=inv[i+j+s+1];
if(i)p->a[id][num[i-1][j][s]]=STC(i)*ni%md;
if(j){
if(i+j+s<k)p->a[id][num[i+1][j-1][s+1]]=STC(j)*ni%md;else
p->a[id][num[i+1][j-1][s]]=STC(j)*ni%md;
}
if(s){
if(i+j+s<k)p->a[id][num[i][j+1][s]]=STC(s)*ni%md;else
p->a[id][num[i][j+1][s-1]]=STC(s)*ni%md;
}
p->a[id][cnt]=p->a[id][id]=ni;
}
p->a[cnt][cnt]=1;
p->r=p->c=cnt+1;
a.r=1,a.c=cnt+1;
}
void init2(){
int cnt=0;
for(int i=0;i<=k;++i)
for(int j=k-i;~j;--j)
num[i][j][0]=cnt++;
for(int i=0;i<=k;++i)
for(int j=k-i;~j;--j){
int&id=num[i][j][0];
const int ni=inv[i+j+1];
if(i)p->a[id][num[i-1][j][0]]=STC(i)*ni%md;
if(j){
if(i+j<k)p->a[id][num[i+1][j][0]]=STC(j)*ni%md;else
p->a[id][num[i+1][j-1][0]]=STC(j)*ni%md;
}
p->a[id][cnt]=p->a[id][id]=ni;
}
p->a[cnt][cnt]=1;
p->r=p->c=cnt+1;
a.r=1,a.c=cnt+1;
}
int main(){
inv[1]=1;
for(int i=2;i<10;++i)
inv[i]=STC(md-md/i)*inv[md%i]%md;
scanf("%d%d%d",&T,&m,&k);
if(m==3)init3();else
init2();
for(int i=1;i<62;++i)
p[i]=p[i-1]*p[i-1];
while(T--){
LL n;
scanf("%lld",&n);
memset(*a.a,0,sizeof*a.a);
if(m==1)
a.a[0][num[1][0][0]]=1;else
if(m==2)
a.a[0][num[0][1][0]]=1;else
a.a[0][num[0][0][1]]=1;
for(int i=0;i<62;++i)
if(n>>i&1)a=a*p[i];
printf("%d\n",a.a[0][a.c-1]);
}
return 0;
}