「Solution」P5555 秩序魔咒

· · 题解

\huge\color{skyblue}{\mathcal{My\ Blog}}

Description

给出两个字符串,求最长公共回文子串的长度和数量。

Solution

方法是 PAM,但是重点在于,我们不需要建两个 PAM 然后最后再跑 dfs(当然这样也可做)。

考虑 PAM 的每个结点,其代表的都是一个回文串,所以我们将两个字符串在同一棵 PAM 上分别跑一遍,那些在两次 solve 过程中都被访问过的结点就是两个字符串的公共回文子串。

于是我们用 flag_{0/1} 记录第 1/2solve 中被访问的结点,就能够找出两个字符串的所有公共回文子串,最后再找最大长度并计数即可。

AC\ Code

需要注意的是,这种做法 PAM 的空间复杂度是 O(n+m),空间要开到 600000

#include<bits/stdc++.h>
#define Int inline int
#define Void inline void
#define ri register int
using namespace std;
int n, m, mx, ans;
bool flag[2][600005];
LL read(){
    ll n = 0; int f = 1; char ch = getchar();
    while('0' > ch || ch > '9'){
        if(ch == '-') f *= -1;
        ch = getchar();
    }
    while('0' <= ch && ch <= '9'){
        n = (n << 1) + (n << 3) + ch-'0';
        ch = getchar();
    }
    return f * n;
}
Void input() {}
template<typename Type, typename... Types>
Void input(Type& arg, Types&... args){
    arg = read();
    input(args...);
}
namespace PAM{
    const int SIZE = 6e5;
    struct node{
        int num, len, fail, ch[26];
    }nd[SIZE];
    char s[SIZE];
    int cnt, last, n;
    Void init(){
        nd[0].len = 0, nd[0].fail = 1;
        nd[1].len = -1, nd[1].fail = 0;
        cnt = 1;
        last = 0;
    }
    Int getfail(int p, int k){
        while(k - nd[p].len - 1 < 0 || s[k-nd[p].len-1] != s[k]) p = nd[p].fail;
        return p;
    }
    Void solve(int T){
        for(ri i = 0; i < n; i++){
            int pos = getfail(last, i);
            if(!nd[pos].ch[s[i]-'a']){
                cnt++;
                nd[cnt].fail = nd[getfail(nd[pos].fail, i)].ch[s[i]-'a'];
                nd[cnt].len = nd[pos].len + 2;
                nd[cnt].num = nd[nd[cnt].fail].num + 1;
                nd[pos].ch[s[i]-'a'] = cnt;
            }
            last = nd[pos].ch[s[i]-'a'];
            flag[T][last] = true;
        }
    }
}
int main(){
    PAM::init();
    input(n, m);
    scanf("%s", PAM::s);
    PAM::n = n;
    PAM::solve(0);
    scanf("%s", PAM::s);
    PAM::n = m;
    PAM::solve(1);
    for(ri i = 0; i <= PAM::cnt; i++){
        if(flag[0][i] && flag[1][i]){
            if(PAM::nd[i].len > mx){
                mx = PAM::nd[i].len;
                ans = 1;
            }
            else if(PAM::nd[i].len == mx) ans++;
        }
    }
    printf("%d %d", mx, ans);
    return 0;
}