[ABC343G] Compress Strings题解

· · 题解

题目大意

给定 N (1 \leq N \leq 20) 个字符串,求出一个最短的字符串,使得这个串包含所有给定的 N 个串。

solution

由于 1 \leq N \leq 20 ,所以考虑状压。

可以先把被其他字符串包含的字符串去除,再预处理出 p_{i,j} 表示字符串 i 的后缀与字符串 j 的前缀的最长相同长度。

f_{i,S} 表示已选的字符串的集合为 S,最后一个选的字符串是 i。那么若要从当前状态 i 转移到新状态 j,则必须满足串 i 在集合 S 里,串 j 不在集合 S 里。若满足以上条件,则有转移:

f_{j,S|2^{j-1}}=\min(f_{j,S|2^{j-1}},f_{i,S}+len_{j}-p_{i,j})

其中 len_{i} 表示字符串 i 的长度。

初始值全为 \inff_{i,2^{i-1}}=len_{i}

code:

#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 22
#define L 400001
using namespace std;
int f[N][1<<N],n,p[N][N],pi[L],len[N],ans=2147483647;
string s[N],ss[N];
bool bz[N];
int pre(string s1,string s2){
    string s=s1+s2;
    int l=s.size();
    pi[0]=0;
    for(int i=1;i<l;i++){
        int j=pi[i-1];
        while(j>0&&s[i]!=s[j])j=pi[j-1];
        if(s[i]==s[j])j++;
        pi[i]=j;
    }
    return pi[l-1];
}
int pre1(string s1,string s2){
    string s=s1+s2;
    int l=s.size();
    pi[0]=0;
    for(int i=1;i<l;i++){
        int j=pi[i-1];
        while(j>0&&s[i]!=s[j])j=pi[j-1];
        if(s[i]==s[j])j++;
        pi[i]=j;
        if(i>=2*s1.size()-1&&pi[i]==s1.size())return s1.size();
    }
    return 0;
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++)cin>>ss[i];
    sort(ss+1,ss+n+1);
    n=unique(ss+1,ss+n+1)-ss-1;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            if(pre1(ss[j],ss[i])&&i!=j)bz[j]=1;
    int u=0;
    for(int i=1;i<=n;i++)
        if(!bz[i])s[++u]=ss[i];
    n=u;
    for(int i=1;i<=n;i++)len[i]=s[i].size();
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            if(i!=j){
                //i:first j:second
                //ji
                string g=s[j]+s[i];
                p[i][j]=pre(s[j],s[i]);
                if(p[i][j]==0)p[i][j]=pre1(s[j],s[i]);
            }
        }
    }
    memset(f,127,sizeof(f))
    for(int i=1;i<=n;i++)f[i][1<<(i-1)]=len[i];
    for(int i=0;i<1<<n;i++)
        for(int j=1;j<=n;j++)
            for(int l=1;l<=n;l++)
                if(j!=l&&(i&(1<<(j-1)))&&(!(i&(1<<(l-1)))))
                    f[l][i|(1<<(l-1))]=min(f[j][i]+len[l]-p[j][l],f[l][i|(1<<(l-1))]);
    for(int i=1;i<=n;i++)ans=min(ans,f[i][(1<<n)-1]);
    printf("%d",ans);
}