「题解」洛谷 P9089 「SvR-2」Work

· · 题解

来个时间复杂度真的线性的做法。

前面的部分和题解区那个 AC 自动机做法相同。先都 reverse 一下,找下充分条件,一个串 S 合法当且仅当它的所有前缀都存在一个后缀在 Trie 里面出现过。必要性构造就是每次删掉当前这个前缀的最短的后缀。

然后考虑根据这个必要性的构造划分等价类来 dp。

首先 g_u 表示 u 节点代表字符串在 Trie 中出现过的最短后缀,当 fail_u 存在时 g_u\gets g_{fail[u]},否则 g_u 没有真后缀,那么这个最短后缀就是它自己。

然后 f_i 表示当前这个串以 s_i 为结尾的后缀有多少个是合法的,首先 s[:i] 是合法的,然后删掉它看它前面能接上几个,就是 f_{i-g[u]},这里 us[:i] 在 Trie 中代表的节点。

注意到这里并不需要求出 AC 自动机的转移边,仅需要 Trie 边以及 fail 树。所以这样构建:

考虑 x 是由 Trie 中的父亲 fa 通过 c 转移边得到的。那么检索 y=fail_{fa},若 y 存在字符 c 的转移边,则令 x 指向 tr_{y,c};否则,令 y\gets fail_y 即不断跳 fail 检查是否有 c 出边。若到了初始节点(Trie 的根节点)依然没有 c 出边,则令 fail_x 指向初始节点。

对于每个串在 Trie 中自顶向下考虑尝试寻找 fail 的 y 在最终 fail 树中的深度变化,每次向下走一个儿子 y 深度最多 +1,每次暴力跳 fail 深度都会 -1,均摊下来总的时间复杂度就是 \mathcal{O}(\sum |s|)

所以总的时间复杂度是 \mathcal{O}(\sum |s|),并不会乘上一个字符集大小。

#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
using namespace std;
typedef pair<int,int> pii;
const int N=1000010;
int n;
int tr[N][26],fa[N],fail[N],dep[N],g[N],tot;
int head[N],ent;
struct Edge{
    int nxt,to,col;
}e[N];
int f[N];
void add(int x,int y,int c){
    tr[x][c]=y;fa[y]=x;dep[y]=dep[x]+1;
    e[++ent]={head[x],y,c};head[x]=ent;
}
void Ins(string &s){
    int u=0,l=s.size();
    for(int i=0;i<l;i++){
        if(!tr[u][s[i]-'a'])add(u,++tot,s[i]-'a');
        u=tr[u][s[i]-'a'];
    }
}
void build(){
    queue<pii>q;
    for(int i=0;i<26;i++)if(tr[0][i])q.push(mp(tr[0][i],i));
    while(!q.empty()){
        int u=q.front().fi,c=q.front().se;q.pop();
        int x=fail[fa[u]];
        while(x&&!tr[x][c])x=fail[x];
        if(x!=fa[u])fail[u]=tr[x][c];
        g[u]=fail[u]?g[fail[u]]:dep[u];
        for(int i=head[u];i;i=e[i].nxt)q.push(mp(e[i].to,e[i].col));
    }
}
string s[N];
signed main(){
    ios::sync_with_stdio(0);cin.tie(0);
    cin >> n;
    for(int i=1;i<=n;i++){
        cin >> s[i];
        reverse(s[i].begin(),s[i].end());
        Ins(s[i]);
    }
    build();
    long long ans=0;
    for(int o=1;o<=n;o++){
        int m=s[o].size(),u=0;f[0]=0;
        for(int i=0;i<m;i++){
            u=tr[u][s[o][i]-'a'];
            f[i]=f[i-g[u]]+1;
            ans+=f[i];
        }
    }
    cout << ans << '\n';
    return 0;
}