题解 P3824 【[NOI2017]泳池】

· · 题解

dp神题……+过气线性代数魔法

其实重点是dp,这道题考点其实并不在常系数线性齐次递推式的化简上……

因为并没有将数据范围开到1e5所以暴力多项式取模即可通过本题

本题题解

朴素的dp

首先发现求恰好为k不太好做,变成极大子矩形小于k的概率减小于k-1的概率

然后我们设f_{i}表示这种图形出现的概率:宽为i,底部第i个点恰好是坏点,且这个i×1001的矩形中不存在大于k的极大子矩形

然后我们发现我们的答案就是\frac{f_{n+1}}{1-q}因为只需要除掉那个坏点出现的概率就可以认为前n列是任取的

那么我们仔细看一下这个f数组是可以递推的

我们枚举底部上一个坏点的位置,显然这两个坏点距离不会超过k,不然就会出现一个大于k的子矩形了

那么f_{x}应该等于这个东西

边界条件f_{0}=1

f_{x}=\sum_{i=1}^{min(k,i)}f_{x-i}p_{i}

其中p_{i}应该表示这种图形出现的概率:宽度为i的矩形,且这个i×1001里面不存在面积大于k的极大子矩形。

似乎p_{i}没有什么优秀的计算方式好像也不能递推……

所以我们考虑把P_{i}拆成一堆数的和,换句话说我们把Pdp出来

那么我们可以考虑枚举这个矩形坏点高度的最小值

所以我们设dp_{i,j}表示这种图形的出现概率:宽为i的矩形,坏点高度最小值为j+1,且这个矩形中不会出现面积大于k的极大子矩形

那么我们认真观察一下会发现dp数组是可以递推的!

我们可以从高到低的枚举坏点高度的最小值,然后从左到右枚举第一个坏点的位置进行转移,另外显然宽度为i的矩形坏点高度最小值不得超过k/i+1,所以i×j大于k的dp值我们无需也不能计算出来

那么转移方程大概长这样

dp_{i,j}=(1-q)q^{j}\sum_{t=1}^{i}(\sum_{p=j+1}^{\infty}dp_{t-1,p})(\sum_{p=j}^{\infty}dp_{i-t,p})

如果我们记sdp为这个东西(其实就是后缀和)

sdp_{i,j}=\sum_{p=j}^{\infty}dp_{i,p}

那么转移方程就是

dp_{i,j}=(1-q)q^{j}\sum_{m+n=i-1}sdp_{m,j+1}sdp_{n,j}

当然你可以用ntt加上多项式求逆均摊O(logn)的转移

但是这里暴力卷积就行了因为k只有1000

我们仔细观察一下会发现,如果我们以枚举坏点高度最小值的方式计算p的话我们会发现p大概是这个式子

p_{i}=\sum_{p=1}^{\infty}dp_{i,p}=sdp_{i,1}

所以我们的p就被求出来了……

那么此时我们的目标是求f_{n+1}

暴力计算O(nk)矩阵快速幂计算O(k^3logn)哪个好像都过不去……

然而我们仔细的想一下,我们真的需要转移矩阵的n次幂吗?

其实不是,我们只需要转移矩阵乘以初值向量的之后的向量,我们甚至不需要这个向量,我们只需要他的最后一位

下面呢就是一种处理常系数齐次线性递推式的技术了

常系数齐次线性递推式的快速计算

先解释一下我们要干什么

我们要快速计算一个“递推式”

而这个递推式满足下面几个条件

1.它是"线性"的换句话说递推式中只有常数项和一次项,且必须有一次项

2.它是"齐次"的,所有项的次数必须相等,结合它是“线性”的,我们可以知道这个递推式里没有常数项

3.它是"常系数"的,所有项的系数必须是一个常数

所以,我们大概要快速计算这个递推式的第n项

f_{n}=\sum_{i=1}^{k}f_{n-i}a_{i}

当然可以矩阵快速幂计算,复杂度是O(k^3logn)

但是我们换一个想法

如果我们可以把转移矩阵A的n次幂转化为这样的形式

至于C_{i}是什么你可以认为那是我们构造的奥妙重重的一组数,因n的不同而不同

A^{n}=\sum_{i=0}^{k-1}A^{i}c_{i}

那么我们因为要求的是初始向量St乘转移矩阵的n次幂之后的向量,我们可以在等式两边同时左乘一个St

StA^{n}=\sum_{i=0}^{k-1}StA^{i}c_{i}

由于我们要算的是向量St×A^{n}的第1项

所以刚才的等式应该对所有向量的第1项也成立

(StA^{n})_{1}=\sum_{i=0}^{k-1}c_{i}(StA^{i})_{1}

等等,我们要求St×A^{i}的第1项?

那不就是St的第i项吗?

所以我们得到了这个式子

(StA^{n})_{1}=\sum_{i=0}^{k-1}c_{i}St_{i}

所以换句话说我们只要能构造出这个奥妙重重的c我们就可以成功计算f_{n}

怎么构造呢?

我们假设说我们构造出了这样一个神奇的多项式f使得下列等式成立,这里的0 是0矩阵的意思

\sum_{i=0}^{k}f_{i}A^{i}=0

那么我们对于A^{n}可以把它表示成这种形式

其中G,C是另外两个多项式

A^{n}=G(A)f(A)+C(A)

因为刚才的等式,所以

A^{n}=C(A)=\sum_{i=0}^{k}c_{i}A^{i}

换句话说,我们只需要把A^{n}表示刚才的形式就行了,然后提取多项式C的系数就是我们需要的c了

刚才的式子好像是多项式取模的式子?

于是我们可以快速幂求出A^{n}modf(A)的值

最后一个问题,怎么求F(A)?

如果f(A)非常好求的话矩阵快速幂就没有什么存在的价值了

所以它一般来讲不是很好求,据说需要高斯消元

但是,常系数齐次线性递推的矩阵是特殊的……

所以我们的f(A)的系数可以O(1)得知

F_{k}=1

F_{k-i}=a_{i}

其中a_{i}是递推系数

所以有了这个以后我们就可以快速幂+多项式取模搞出c来

然后就可以求出f_{n}了~

多项式取模可以去洛谷模板区,当然这道题不需要ntt版的多项式取模

直接厂除法进行多项式取模就行了

上代码~

#include<cstdio>
#include<algorithm>
using namespace std;const int N=2048;typedef unsigned long long ll;const ll mod=998244353;
int n;int k;ll p;ll q;ll x;ll y;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
ll sdp[N][N];ll st[N];ll ret[N];ll tr[N];ll f[N];ll cp1[N];ll cp2[N];ll a[N];
inline ll solve(int k)
{
    for(int i=0;i<=k+1;i++)sdp[i][0]=1;//dp
    for(int j=k;j>=1;j--)
        for(int i=1;i*j<=k;i++)
        {
            ll ret=0;for(int t=1;t<=i;t++)(ret+=sdp[j+1][t-1]*sdp[j][i-t])%=mod;
            ret=ret*p%mod*po(q,j)%mod;sdp[j][i]=(sdp[j+1][i]+ret)%mod;
        }
    k++;tr[1]=p;for(int i=1;i<=k-1;i++)tr[i+1]=sdp[1][i]*p%mod;//转移系数
    st[0]=1;for(int i=1;i<k;i++)for(int j=0;j<i;j++)(st[i]+=st[j]*tr[i-j])%=mod;//初值
    for(int i=1;i<=k;i++)f[k-i]=mod-tr[i];f[k]=1;ret[0]=1;a[1]=1;int t=n+1;
    while(t)//快速幂
    {
        if(t&1)
        {
            for(int i=0;i<=k;i++)cp1[i]=ret[i],ret[i]=0;
            for(int i=0;i<=k;i++)for(int j=0;j<=k;j++)(ret[i+j]+=cp1[i]*a[j])%=mod;//卷积
            for(int i=2*k;i>=k;i--)//厂除法取模
                for(int j=0;j<=k;j++)(ret[i-k+j]+=mod-ret[i]*f[j]%mod)%=mod;
        }
        for(int i=0;i<=k;i++)cp1[i]=a[i];for(int i=0;i<=k;i++)cp2[i]=a[i],a[i]=0;
        for(int i=0;i<=k;i++)for(int j=0;j<=k;j++)(a[i+j]+=cp1[i]*cp2[j])%=mod;
        for(int i=2*k;i>=k;i--)
            for(int j=0;j<=k;j++)(a[i-k+j]+=mod-a[i]*f[j]%mod)%=mod;t>>=1;
    }ll ans=0;for(int i=0;i<k;i++)(ans+=st[i]*ret[i])%=mod;
    for(int i=0;i<=k+1;i++)for(int j=0;j<=k+1;j++)sdp[i][j]=0;
    for(int i=0;i<=k;i++)a[i]=0;for(int i=0;i<=k;i++)ret[i]=0;
    for(int i=0;i<=k;i++)st[i]=0;return ans*po(p,mod-2)%mod;输出
}
int main()
{
    scanf("%d%d%lld%lld",&n,&k,&x,&y);q=x*po(y,mod-2)%mod;p=(1+mod-q)%mod;
    printf("%lld",(solve(k)+mod-solve(k-1))%mod);return 0;//拜拜程序~
}