题解 P6009 【[USACO20JAN]Non-Decreasing Subsequences P】

· · 题解

题意

给一个长度为 n,值域在 [1,k] 内的序列 aq 次询问,每次询问给 l,r,问你 a_l\sim a_r 内不降子序列的总数。

\texttt{Data Range:}n\leq 5\times 10^4,k\leq 20,q\leq 10^5

题解

考虑如何求 l=1,r=n 时的答案。

注意到值域范围很小,于是可以 \texttt{dp},设 f_{i,j} 表示前 i 个数,钦定最后一个为 j 的不降子序列总数,枚举一下得到

f_{i,j}\begin{cases}f_{i-1,j},&j\neq a_i\\f_{i-1,j}+\sum\limits_{k=1}^{j}f_{i-1,k},&j=a_i\end{cases}

注意到可以维护矩阵来进行转移,第 r 个位置的转移矩阵大概长这样:

T_{r,i,j}=[i=j]+[i\leq j][j=a_r]

这样可以求得答案为

\begin{pmatrix}1&0&\cdots&0\end{pmatrix}\prod\limits_{i=1}^{n}T_i\begin{pmatrix}1\\1\\\vdots\\1\end{pmatrix}

现在把他迁移到任意区间上来,得到

\begin{pmatrix}1&0&\cdots&0\end{pmatrix}\prod\limits_{i=l}^{r}T_i\begin{pmatrix}1\\1\\\vdots\\1\end{pmatrix}

考虑差分,得到

\begin{pmatrix}1&0&\cdots&0\end{pmatrix}\prod\limits_{i=l-1}^{1}T_i^{-1}\prod\limits_{i=1}^{r}T_i\begin{pmatrix}1\\1\\\vdots\\1\end{pmatrix}

考虑这个东西的逆矩阵是啥。根据人类智慧得到

T_{r,i,j}^{-1}=[i=j]-\frac{1}{2}[i\leq j][j=a_r]

于是维护

L_x=\prod\limits_{i=1}^{x}T_i\begin{pmatrix}1\\1\\\vdots\\1\end{pmatrix}

R_x=\begin{pmatrix}1&0&\cdots&0\end{pmatrix}\prod\limits_{i=x}^{1}T_i^{-1}

就可以得到答案为 R_{l-1}L_r

但是,暴力矩乘不可取,考虑优化。

注意到 T_r=I+A_r,并且 A_{r,i,j}=[i\leq j][j=r]

右乘和左乘都可以 O(n^2) 求出,总时间复杂度 O(nk^2+qk)

代码

#include<bits/stdc++.h>
using namespace std;
typedef int ll;
typedef long long int li;
const ll MAXN=2e5+51,MOD=1e9+7;
ll n,kk,qcnt,l,r,res;
ll x[MAXN],p[25][25],f[MAXN][25],g[MAXN][25];
inline ll read()
{
    register ll num=0,neg=1;
    register char ch=getchar();
    while(!isdigit(ch)&&ch!='-')
    {
        ch=getchar();
    }
    if(ch=='-')
    {
        neg=-1;
        ch=getchar();
    }
    while(isdigit(ch))
    {
        num=(num<<3)+(num<<1)+(ch-'0');
        ch=getchar();
    }
    return num*neg;
}
int main()
{
    n=read(),kk=read();
    for(register int i=1;i<=kk;i++)
    {
        p[i][i]=1;
    }
    for(register int i=1;i<=n;i++)
    {
        x[i]=read();
        for(register int j=1;j<=kk;j++)
        {
            for(register int k=x[i];k;k--)
            {
                p[j][x[i]]=(p[j][k]+p[j][x[i]])%MOD;
            }
        }
        for(register int j=1;j<=kk;j++)
        {
            for(register int k=1;k<=kk;k++)
            {
                f[i][j]=(f[i][j]+p[j][k])%MOD;
            }
        }
    }
    memset(p,0,sizeof(p)),qcnt=read(),g[0][1]=1;
    for(register int i=1;i<=kk;i++)
    {
        p[i][i]=1;
    }
    for(register int i=1;i<=n;i++)
    {
        for(register int j=1;j<=x[i];j++)
        {
            for(register int k=1;k<=kk;k++)
            {
                p[j][k]=(p[j][k]+(li)(5e8+3)*p[x[i]][k]%MOD)%MOD;
            }
        }
        for(register int j=1;j<=kk;j++)
        {
            g[i][j]=p[1][j];
        }
    }
    for(register int i=0;i<qcnt;i++)
    {
        l=read(),r=read(),res=0;
        for(register int j=1;j<=kk;j++)
        {
            res=(res+(li)g[l-1][j]*f[r][j]%MOD)%MOD;
        }
        printf("%d\n",res);
    }
}