AC 自动机 ACAM 学习笔记

· · 算法·理论

感觉这个东西就是对于 KMP 的生搬硬套,Fail 指针的匹配策略完全可以用 KMP 理解,但是这不妨碍 AC 自动机是一个强大的自动机。

当然还是不太一样的/tiao。

如果字符串的匹配,模式串有多个,那么考虑建 Trie 树,这样的好处是字符串之间融为一体又相互独立,然后我们考虑在树上建一个 Fail 指针,如果失配了,就跳到 Fail 指针的位置。

这和 KMP 非常相似,但是我们匹配的过程实际上是多个字符串同时匹配的,减少了时间复杂度。

随便说一个性质:

构建 Fail 指针

类似于 KMP,假设当前字符为 c 找到父亲的 Fail 指针所指之处,看是否存在 \to c 的边。

这个就很像 KMP 了。

实现上,注意到一个重要的性质,就是 Fail 指针的深度不会比我深,所以按层数递增顺序遍历的 BFS 比较好。

然后我看见 OI Wiki 的实现很让人惊艳:

void build()
{
    for(int i=0;i<M;i++)
    {
        if(t[0].ch[i])q.push(t[0].ch[i]);
    }
    while(!q.empty())
    {
        LL x=q.front();
        q.pop();
        for(int i=0;i<M;i++)
        {
            if(t[x].ch[i])
            {
                t[t[x].ch[i]].fail=t[t[x].fail].ch[i];
                q.push(t[x].ch[i]);
            }
            else t[x].ch[i]=t[t[x].fail].ch[i];
        }
    }
}

当然这一份是我按自己的马蜂写的。

显然让人比较茫然的地方是 t[x].ch[i]=t[t[x].fail].ch[i],我儿子怎么变成 Fail 指针的了?

唉?这样 Trie 树的形态不是被破坏了吗?唉唉唉?这连树都不是了啊?

其实这只是一个路径压缩,我们刚才说了,找 Fail 指针的时候是一个一个跳的,太慢了,改成这样好的多。

除了找 Fail 指针是路径压缩之外,你本来没有这个儿子,你就失配了就得跳,直到跳到一个地方之后,跳到这个儿子。

你现在直接连向这个儿子,也是一个路径压缩。

顺便说说 Fail 的性质:

仔细观察,易得 Fail 指针指向的那一项是我们这一项表示的字符串的后缀。

然后有一个比较重要的性质,每个点和 Fail 指针指向的点连边,构成一棵树(以下称为 Fail 树)。

下面我们把一个点在 Fail 树上到根节点的链叫做 Fail 链。

例题 1

多个模式串,一个文本串,求有几个模式串在文本串中出现过。

在每个节点的结束节点打标记,显然标记可以叠加,然后让文本串在 ACAM 上移动,每次移动到一个点,显然说明整个 Fail 链的节点表示的字符串都是有的(已知 Fail 指针指向的是自己的后缀),而且这样是不漏的。

由于每个节点只取一次标记,所以时间复杂度 \mathcal O(n)

#include<bits/stdc++.h>
#define LL long long
#define pb push_back
using namespace std;
const LL N=1e6+5;
const LL M=26;
struct trie
{
    LL ch[M],fail,val;
}t[N];
queue<LL>q;
LL n,fail[N],tot,ans,hav[N];
char c[N];
void ins(char *c)
{
    LL now=0,n=strlen(c+1);
    for(int i=1;i<=n;i++)
    {
        if(!t[now].ch[c[i]-'a'])t[now].ch[c[i]-'a']=++tot;
        now=t[now].ch[c[i]-'a'];
    }
    t[now].val++;
}
void build()
{
    for(int i=0;i<M;i++)
    {
        if(t[0].ch[i])q.push(t[0].ch[i]);
    }
    while(!q.empty())
    {
        LL x=q.front();
        q.pop();
        for(int i=0;i<M;i++)
        {
            if(t[x].ch[i])
            {
                t[t[x].ch[i]].fail=t[t[x].fail].ch[i];
                q.push(t[x].ch[i]);
            }
            else t[x].ch[i]=t[t[x].fail].ch[i];
        }
    }
}
void find(char *c)
{
    LL n=strlen(c+1),now=0;
    for(int i=1;i<=n;i++)
    {
        now=t[now].ch[c[i]-'a'];
        for(int j=now;j&&t[j].val!=-1;j=t[j].fail)
        {
            ans+=t[j].val;
            t[j].val=-1;
        }
    }
}
int main()
{
    scanf("%lld",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%s",c+1);
        ins(c);
    }
    build();
    scanf("%s",c+1);
    find(c);
    printf("%lld\n",ans);
}

例题 2

多个模式串,一个文本串,求每个字符串出现了几次。

我们统计每个结点表示的字符串出现了几遍,让文本串在 ACAM 上移动,每次移动到一个点,显然说明整个 Fail 链的节点在字符串的这个位置出现了一次,所以写个树上差分就行。

时间复杂度为 \mathcal O(n)

#include<bits/stdc++.h>
#define LL long long
#define pb push_back
#define pLL pair<LL,LL>
using namespace std;
const LL N=2e6+5;
const LL M=26;
const LL K=1e6+5;
struct trie
{
    LL ch[M],fail;
}t[K];
queue<LL>q;
vector<LL>v[N];
LL n,tot,cnt[N],endpos[N];
char c[N];
void ins(char *c,LL nam)
{
    LL now=0,n=strlen(c+1);
    for(int i=1;i<=n;i++)
    {
        if(!t[now].ch[c[i]-'a'])t[now].ch[c[i]-'a']=++tot;
        now=t[now].ch[c[i]-'a'];
    }
    endpos[nam]=now;
}
void build()
{
    for(int i=0;i<M;i++)
    {
        if(t[0].ch[i])q.push(t[0].ch[i]);
    }
    while(!q.empty())
    {
        LL x=q.front();
        q.pop();
        for(int i=0;i<M;i++)
        {
            if(t[x].ch[i])
            {
                t[t[x].ch[i]].fail=t[t[x].fail].ch[i];
                q.push(t[x].ch[i]);
            }
            else t[x].ch[i]=t[t[x].fail].ch[i];
        }
    }
}
void find(char *c)
{
    LL n=strlen(c+1),now=0;
    for(int i=1;i<=n;i++)
    {
        now=t[now].ch[c[i]-'a'];
        cnt[now]++;
    }
}
void dfs(LL x)
{
    for(LL i:v[x])
    {
        dfs(i);
        cnt[x]+=cnt[i];
    }
}
int main()
{
    //freopen("3.in","r",stdin);
    scanf("%lld",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%s",c+1);
        ins(c,i);
    }
    build();
    scanf("%s",c+1);
    find(c);
    for(int i=1;i<=tot;i++)
    {
        v[t[i].fail].pb(i);
    }
    dfs(0);
    for(int i=1;i<=n;i++)
    {
        printf("%lld\n",cnt[endpos[i]]);
    }
}

结语

理解建议去 OI Wiki 看一手图,或者看看你的 KMP 学习笔记。

我是一个先学 SAM 再学 SA 再学 AC 自动机的神秘的人。