题解 CF710F 【String Set Queries】

· · 题解

解法:二进制分组 + AC自动机

啥是二进制分组?

就是把所有的操作按二进制分组,每一组里面就开个数据结构维护。

怎么分组呢?

因为是在线算法,所以这个东西有点像一个叫2048的游戏。

假设之前已经分好两组,每组的大小为 4,2(大小就是里面包含了多少个操作)。

加一个操作,那么就有三组 4,2,1;如果再加一个操作,就是 4,2,1,1。

此时有两个 1,就要把那两组暴力重构,合并成一组,得到 4,2,2;再合并 ...... 最终剩余一组,大小为 8 。

可以得到,如果有 n 个操作,那么至多会分为 \log n 组,因为它分得的组数恰好就是 n 在二进制的表示下 1 的个数。(例如上面的 7 个操作被分成了 3 组)

那么暴力重构的时间复杂度呢?

考虑每个点会被重构多少次,如果把合并的过程画出来的话(如下图),你会发现这很像是一棵线段树,那么每个点至多被重构 \log n 次。

最后乘上用数据结构维护对应的复杂度就是重构的复杂度了。

回到这题,首先得会 AC自动机 并 A 了这道模板题:【模板】AC自动机(二次加强版)

然后你会发现这题反而简单点的,模板题是求每个模式串的贡献,这题只需求所有模式串的总贡献。

模板题并没有每次暴力跳 fail 边而是打了个标记,那我们也无需暴力求,像弄个前缀和一样在建树时沿着 fail 边累加就行。

对于插入操作,对它进行二进制分组,每个组建一个AC自动机,到时候查询就是把 \log n 个AC自动机 search 一遍就好了。

最后就是删除操作,可以发现贡献具有可减性

什么意思?

就是我们可以求出所有串的贡献,用它减去已删除串的贡献,得到的就是答案。

这样一来只要对插入,删除的字符串各自进行二进制分组,然后相减即可。

AC自动机的建树是 O(n) 的,所以时间复杂度是 O(len \log n)len 表示输入的总串长。

第一次写写丑了,常数有点大。

如果用了 printf 的话在后面记得加上这句话:fflush(stdout); 我也不知道为什么

code:
// 二进制分组 + AC自动机
#include <cstdio>
#include <string>
#include <cstring>
#include <queue>
using namespace std;
const int maxn = 300005;
struct Group {
    string data[maxn];
    int tot,ch[maxn][26],cnt[maxn],n;
    void insert(int r,char *s) {
        int d=s[1]-'a',len=strlen(s+1),_r=r;
        for(int i=1;i<=len;d=s[++i]-'a') {
            if(ch[r][d]==_r) {
                ch[r][d]=++tot;
                for(int j=0;j<26;j++) ch[tot][j]=_r;
            }
            r=ch[r][d];
        } ++cnt[r];
    }
    int fail[maxn];
    void build(int r) {
        fail[r]=r;
        queue<int>q;
        for(int i=0;i<26;i++)
            if(ch[r][i]>r) q.push(ch[r][i]),fail[ch[r][i]]=r;
        while(!q.empty()) {
            int u=q.front(); q.pop();
            for(int i=0;i<26;i++) {
                int v=ch[u][i];
                if(v==r) ch[u][i]=ch[fail[u]][i];
                else {
                    fail[v]=ch[fail[u]][i];
                    cnt[v]+=cnt[ch[fail[u]][i]];
                    q.push(v);
                }
            }
        }
    }
    int search(int r,char *s) {
        int d=s[1]-'a',len=strlen(s+1),res=0;
        for(int i=1;i<=len;d=s[++i]-'a')
            r=ch[r][d],res+=cnt[r];
        return res;
    }
    int stk[maxn],fr[maxn],size[maxn],N;
    void merge() {
        --N,size[N]<<=1;
        for(int i=stk[N];i<=tot;i++) {
            cnt[i]=fail[i]=0;
            for(int j=0;j<26;j++) ch[i][j]=0;
        }
        tot=stk[N];
        for(int i=0;i<26;i++) ch[tot][i]=tot;
        for(int L=fr[N];L<=n;L++) insert(stk[N],&data[L][0]);
        build(stk[N]);
    }
    void Insert(char *s) {
        stk[++N]=++tot;
        for(int i=0;i<26;i++) ch[tot][i]=tot;
        size[N]=1;
        insert(tot,s),build(stk[N]);
        int len=strlen(s+1);
        data[fr[N]=++n]=" ";
        for(int i=1;i<=len;i++) data[n]+=s[i];
        while(size[N]==size[N-1]) merge();
    }
    int Count(char *s) {
        int ans=0;
        for(int i=1;i<=N;i++)
            ans+=search(stk[i],s);
        return ans;
    }
}Add,Sub;
char tmp[maxn];
int main() {
    int T;
    scanf("%d",&T);
    for(int i=1;i<=T;i++) {
        int opt;
        scanf("%d%s",&opt,tmp+1);
        if(opt==1) Add.Insert(tmp);
        else if(opt==2) Sub.Insert(tmp);
        else printf("%d\n",Add.Count(tmp)-Sub.Count(tmp));
        fflush(stdout);
    }
    return 0;
}