[ZJOI2016] 线段树

· · 题解

先考虑 01 序列怎么做。我们统计每一个位置最终为 0 的方案数。

如果 a_i 初始是 1,那么最终不可能为 0

如果 a_i 初始是 0,那么设它左边最后一个 1 的位置为 L,右边第一个 1 的位置为 Ra_i 最终是 0 等价于 i\in(L,R)。每一次操作有可能会将 (L,R) 缩小一点。

我们对初始每一个 0 的极长连续段做一遍 dp。dp_{x,L,R} 表示 x 次操作之后,缩到 (L,R) 的方案数。转移方程:

dp_{x,L,R}=dp_{x-1,L,R}\times\left(\dfrac{L(L+1)+(n-R+1)(n-R+2)+(R-L-1)(R-L)}{2}\right) +dp_{x-1,L',R}\times L'+dp_{x-1,L,R'}\times (n-R'+1)

这个柿子显然可以使用前缀和优化,每一次 dp 的复杂度就是 O(n^2m),总复杂度也是 O(n^2m)

考虑一般序列怎么做。注意到最终序列的某个数 w 可以拆为 w=\max a_i-\sum\limits_{i=0}^{\max a_i}[w<i]

我们枚举这个 i,并把所有 \ge i 的位置标为 1,把所有 <i 的位置标为 0。此时我们就可以算出每一个位置 <i 的方案数。

可以注意到 i0 扫到 \max a_i 的过程中 01 序列状态变化次数只有 O(n) 次。此时我们得到了一个 O(n^3m) 的算法,但由于数据随机,实际复杂度是 O(n^2m) 的。一个单调递增的序列就能将这个算法卡到 O(n^3m)

但其实我们可以找出稳定 O(n^2m) 的算法。可以发现,所有的 dp 值的转移柿子都完全一致,所以我们可以把所有初始值放到同一个 dp 数组里面,然后进行一次整体的 dp,就可以求出答案。

关于为什么不拆为 w=\sum\limits_{i=0}^{\max a_i}[w>i] 来做。因为这样做非常困难,统计贡献的时候会与 L,R 的初值有关,不利于进行整体 dp。

时间复杂度 O(n^2m)

参考代码:

#include<bits/stdc++.h>
using namespace std;
#define N 405
#define MOD 1000000007
int n,m,ans[N],dp[N][N],s1[N][N],s2[N][N];bool vs[N];struct Node {int id,w;}a[N];
bool cmp(Node x,Node y) {return x.w<y.w;}
int add(int x,int y) {x+=y;return x<MOD?x:x-MOD;}
void W(int &x,int y) {x+=y;if(x>=MOD) x-=MOD;}
int qPow(int x,int y)
{int res=1;for(;y;y/=2,x=1ll*x*x%MOD) if(y&1) res=1ll*res*x%MOD;return res;}
int f(int l,int r) {return (l*(l+1)+(n-r+1)*(n-r+2)+(r-l-1)*(r-l))/2;}
int main()
{
    scanf("%d %d",&n,&m);vs[n+1]=1;
    for(int i=1;i<=n;++i) scanf("%d",&a[i].w),a[i].id=i;sort(a+1,a+n+1,cmp);
    for(int i=n,lst;i;--i)
    {
        lst=0;vs[a[i].id]=1;
        for(int j=1;j<=n+1;++j) if(vs[j]) dp[lst][j]+=a[i].w-a[i-1].w,lst=j;
    }ans[1]=1ll*a[n].w*qPow(n*(n+1)/2,m)%MOD;for(int i=2;i<=n;++i) ans[i]=ans[1];
    for(int i=1;i<=m;++i)
    {
        for(int j=0;j<=n;++j) for(int k=n+1;k>j+1;--k)
        {
            s1[j][k]=add(j?s1[j-1][k]:0,1ll*dp[j][k]*j%MOD);
            s2[j][k]=add(k<=n?s2[j][k+1]:0,1ll*dp[j][k]*(n-k+1)%MOD);
        }
        for(int j=0;j<=n;++j) for(int k=j+2;k<=n+1;++k)
        {
            dp[j][k]=1ll*dp[j][k]*f(j,k)%MOD;
            if(j) W(dp[j][k],s1[j-1][k]);if(k<=n) W(dp[j][k],s2[j][k+1]);
        }
    }for(int i=0;i<=n;++i) for(int j=i+2;j<=n+1;++j)
        for(int k=i+1;k<j;++k) W(ans[k],MOD-dp[i][j]);
    for(int i=1;i<=n;++i) printf("%d ",ans[i]);return 0;
}