题解 P4156【[WC2016]论战捆竹竿】

· · 题解

看到初二的学弟已经开始卷这些了,非常恐怖。这就是我校初二真实实力吗?

P4156 [WC2016]论战捆竹竿解题报告:

更好的阅读体验

题意

给定一个长度为 n 的字符串 S,一开始有一个空串,每次可以将一个新的 S 去掉一个 Border 后拼上去,求长度区间 [n,w] 内有多少长度可以被拼出来。(1\leqslant n\leqslant 5\times 10^5,1\leqslant 10^{18}

分析

妙妙题。

那么我们用 kmp 求出所有 Border,那么就可以得到所有周期 c_{1\cdots k},我们的问题转化为求 \sum_{i=1}^k a_ic_i 在区间 [0,n-w] 中可以有多少取值。

我们考虑随便取一个不小于所有 c_i 的模数 mod,对于所有 i\in[0,mod-1],求出最小的 k 使得 k\equiv i\pmod{mod},然后就可以直接算了。

这明显是同余最短路模型,我们如果直接取所有 c_i 的最大值为模数,让所有点 x 都连若干条边到 x+c_1,x+c_2,\cdots,x+c_k 跑一遍最短路复杂度就是 O(n^2),有暴力分。

我们肯定不会把所有边建出来,于是我们考虑另一种暴力,对于每个 c_i 单独将它的边更新,每次继承上一次的最短路数组。

但是如何更新仍然是问题,不难发现对于串长 len 最短路图可以分成 \gcd(n,len) 个部分,每个部分之间没有边相连,且每个部分事实上就是从每个位置开始每次向后走 len 步的所有点,那么我们就可以把整个图变成若干个部分,每个部分从最小的位置一步步更新到后面直到重合就好了,但这样仍然是 O(n^2) 的。

事实上,一个简单的结论可以大大降低复杂度:一个字符串所有的 Border 排序后,长度会形成 O(\log n) 个等差数列。

这启发我们将等差数列放在一起跑最短路,具体地,我们发现对于一个首项为 len,公差为 dif,项数为 tot 的等差数列,我们并不好处理首项,于是可以直接把首项当成一个公差为 len,项数为 1 的等差数列的公差处理。

此时不难发现忽略了首项长度后,一个长度为 tot 的等差数列就是向前连 tot 条边,这个东西可以直接单调队列维护。

代码也很好写,时间复杂度:O(Tn\log n)

代码

#include<stdio.h>
#include<iostream>
using namespace std;
const int maxn=500005;
int T,n,cs,pos;
int p[maxn],c[maxn];
long long w,ans;
long long dis[maxn],tmp[maxn];
pair<int,long long>q[maxn];
string s; 
int gcd(int a,int b){
    return b==0? a:gcd(b,a%b);
}
void getborder(){
    cs=0,p[0]=0;
    for(int i=1;i<n;i++){
        int j=p[i-1];
        while(j&&s[i]!=s[j])
            j=p[j-1];
        if(s[i]==s[j])
            j++;
        p[i]=j;
    }
    for(int i=p[n-1];i;i=p[i-1])
        c[++cs]=n-i;
    c[++cs]=n;
}
void getdis(int newpos,int len,int dif,int tot){
    int mod=c[newpos],g=gcd(dif,mod);
    for(int i=0;i<g;i++){
        int st=i;
        for(int j=(i+dif)%mod;j!=i;j=(j+dif)%mod)
            if(dis[j]<dis[st])
                st=j;
        int l=1,r=0;
        q[++r]=make_pair(0,dis[st]);
        for(int j=(st+dif)%mod,k=1;j!=st;j=(j+dif)%mod,k++){
            while(l<=r&&k-q[l].first>tot)
                l++;
            dis[j]=min(dis[j],q[l].second+1ll*(k-q[l].first)*dif+len);
            while(l<=r&&q[r].second+1ll*(k-q[r].first)*dif>=dis[j])
                r--;
            q[++r]=make_pair(k,dis[j]);
        }
    }
    pos=newpos;
}
void solve(){
    for(int i=0;i<c[1];i++)
        dis[i]=w+1;
    pos=1,dis[n%c[1]]=n;
    int l=1,r;
    while(l<cs){
        for(r=l+1;r<=cs&&c[r]-c[r-1]==c[l+1]-c[l];r++);
        r--;
        if(l<r)
            getdis(l,c[l],c[l+1]-c[l],r-l);
        if(r+1<=cs){
            for(int i=0;i<c[r+1];i++)
                tmp[i]=w+1;
            for(int i=0;i<c[l];i++)
                tmp[dis[i]%c[r+1]]=min(tmp[dis[i]%c[r+1]],dis[i]);
            for(int i=0;i<c[r+1];i++)
                dis[i]=tmp[i];
            getdis(r+1,0,c[l],1);
        }
        l=r+1;
    }
}
int main(){
    scanf("%d",&T);
    while(T--){
        scanf("%d%lld",&n,&w),cin>>s;
        getborder(),solve();
        ans=0;
        for(int i=0;i<c[pos];i++)
            if(dis[i]<=w)
                ans+=(w-dis[i])/c[pos]+1;
        printf("%lld\n",ans);
    }
    return 0;
}