【题解】P9753 [CSP-S 2023] 消消乐(字符串哈希,DP)

· · 题解

【题解】P9753 [CSP-S 2023] 消消乐

不知道考场脑子是抽了还是有病,全程都不知道在放什么屁。

特别鸣谢:@dbxxx 给我讲解了解法一的满分做法,并让我对哈希有了更加深刻的认识;@Daidly 给我讲解了解法二。

博客园食用效果更佳

题目链接

P9753 [CSP-S 2023] 消消乐

题意概述

给定一个长度为 n 的只含小写字母的字符串,每次可以选择相邻两个字母消除,消除后前后未被消除的序列会自动拼接到一起。一个字符串是“可消除的”当且仅当对该串进行若干次上述操作后,可以变成一个空字符串。求给定的字符串有多少个子串是可消除的。

数据范围

对于所有测试数据有:1 \le n \le 2 \times 10^6,且询问的字符串仅由小写字母构成。

测试点 n\leq 特殊性质
1\sim 5 10
6\sim 7 800
8\sim 10 8000
11\sim 12 2\times 10^5 A
13\sim 14 2\times 10^5 B
15\sim 17 2\times 10^5
18\sim 20 2\times 10^6

特殊性质 A:字符串中的每个字符独立等概率地从字符集中选择。

特殊性质 B:字符串仅由 ab 构成。

思路分析

解法一

首先考虑一个字符串什么时候是“可消除的”,我们可以考虑类似于括号匹配的办法:

维护一个栈,按顺序遍历字符串,若当前字符恰好等于栈顶,则弹出栈顶;反之则将该字符入栈。若遍历结束后,栈为空,则说明该字符串是“可消除的“。

题目要求对于一个字符串所有的子串是否为“可消除的”,那么最暴力的想法就是暴力枚举该字符串的每个子串 [l,r] 并做上述括号匹配来判断,若该子串是“可消除的”,则方案数 +1

时间复杂度 O(n^3),期望得分 35pts。

考虑优化。注意到对于起点 l 相同的子串,只要维护过程中栈为空,那么就说明从 l 到当前为止是个“可消除的”的字符串。所以我们不需要每次遍历子串中的每一个字符,维护多个栈,只需要遍历一次 ln,维护一个栈来解决,具体地:

遍历 1n 的每一个数作为子串的起点 l,每次维护一个栈,遍历从 ln 的所有字符,做一遍括号匹配,同时在过程中维护栈被弹为空的次数 cnt,每次让答案加上 cnt 即可。

时间复杂度 O(n^2),期望得分 50pts。

考虑特殊性质 A,发现在随机序列下,符合条件的子串非常短,那么我们只需要选择区间长度较小的子串进行验证,就可以在题目要求的时间内过掉这两个点。

该特殊性质加上 O(n^2) 的做法,期望得分 60pts。

O(n^2) 做法的启发,考虑如何减少枚举次数。

发现 O(n^2) 做法相较于 O(n^3) 做法优化在于,只需要枚举子串起点,不需要枚举子串终点,就可以通过维护一次栈来求出以 l 起点的所有方案。

维护栈似乎是无法优化的,那么考虑我们如何做可以不用枚举子串起点。

发现在从 1n 维护栈序列的时候,若对于某个时刻 l 和某个时刻 r,两种时刻的栈序列完全相同,那么说明子串 [l+1,r] 一定是可消除的。

那么我们可以通过字符串哈希来维护每个时刻的栈序列,那么栈序列相同说明该情况下哈希值完全相同,可以用 mapunordered_map 来维护每种哈希值出现了多少次。假设一种哈希值出现了 k 次,那么其对答案的贡献就是 \mathrm C_k^2 = \dfrac{k\times (k-1)}{2}。即 k 个相同的时刻,每次取两个时刻 lr 构成的子串 [l+1,r] 是“可消除的”。

对于每种哈希值对答案的贡献求和,即为最终答案。

时间复杂度:map 维护 O(n \log n)unordered_map 维护 O(n)

注意:若采用单模数哈希,模数如果为 99824435310^9+7,相当于是 2\times 10^6 个数要落在大约是 [0,10^9] 这个区间,产生哈希冲突的可能性较大。所以最好用双模数哈希/自然溢出。双模数哈希相当于随机范围是两个模数相乘,自然溢出相当于是 [0,2^{64}],产生哈希冲突的可能性较小。由于自然溢出更好写,我采用的是自然溢出。

解法二

上一种解法,我们其实主要是站在「如何消」的角度展开思考的,下来我们来站在「区间合法性」的角度来思考这个问题。

说明:以下分析用 s_i 表示给定字符串的第 i 个字符。

考虑 DP,我们用 dp_i 表示以 i 为后缀的合法区间个数。

考虑 dp_i 能由谁转移过来。枚举 j<i,那么假如 dp_i 可以由 dp_j 更新,一定说明区间 [j+1,i] 是“可消除的”。要保证不重不漏,那么 j 一定是 i 前面满足 [j+1,i] 是可消除的最大的位置。

要使得 j 最大,那么 [j+1,i] 一定是以 i 结尾最短的合法区间,令 lst_i=j+1,那么 [lst_i,i] 就是以 i 结尾最短的”可消除区间“。所以对于每个 i,有:

dp_i=dp_{lst_i-1}+1

现在考虑如何找 lst_i

我们假设 dp_i 已经求出,考虑枚举到 i+1 时,dp_{i+1} 应该怎么求。对于 s_{i+1},我们有以下两种情况:

考虑复杂度分析。

定义「等价类」表示解法一(哈希做法)中同一个前缀栈哈希值。那么我们发现我们往前跳 lst 一定是在这个哈希值内部跳 lst

假如我们当前在某个等价类的最后一个字符,那么这个等价类就可以构成一个字符串。然后要在当前为止后面再加一个字符 a,那么实际上就是要找当前等价类字符串里的最后一个 a,那么时间复杂度就是等价类字符串里两个 a 的下标差。

即对于单个字符 a,查找 lst 的时间复杂度就是该等价类字符串里,最后一个 a 的下标减去倒数第二个下标加上倒数第二个下标减去倒数第三个下标,以此类推,线性相当于等价类大小。由于等价类大小总和为线性,那么单个字符的复杂度线性,所以总复杂度就是 O(n|\Sigma|),即 O(26n)

代码实现

解法一

//B
//哈希做法
//The Way to The Terminal Station……
#include<cstdio>
#include<iostream>
#include<map>
#include<stack>
#include<set>
#define int unsigned long long 
const int P=29;
using namespace std;
const int maxn=2e6+10;
int has[maxn],ans;

map<int,int>mp;
stack<int>q,pos;
set<int>S;

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
    return x*f;
}

signed main()
{
    int n=read();
    string s;
    cin>>s;s='%'+s;
    S.insert(0);
    mp[0]++;
    for(int i=1;i<=n;i++)
    {
        if(!q.empty()&&q.top()==s[i]-'a')
        {
            mp[has[pos.top()-1]]++;
            has[i]=has[pos.top()-1];
            q.pop();
            pos.pop();
        }
        else
        {
            q.push(s[i]-'a');
            pos.push(i);
            has[i]=has[i-1]*P+s[i]-'a'+1;
            S.insert(has[i]);
            mp[has[i]]++;
        }
    }
    for(int v:S)ans+=mp[v]*(mp[v]-1)/2;
    cout<<ans<<'\n';
    return 0;
}

解法二

//B
//The Way to The Terminal Station…
#include<cstdio>
#include<iostream>
#define int long long
using namespace std;
const int maxn=2e6+10;
int dp[maxn],lst[maxn];

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
    return x*f;
}

signed main()
{
    int n=read();
    string s;
    cin>>s;s='%'+s;
    int ans=0;
    for(int i=1;i<=n;i++)
    {
        for(int j=i-1;j>0;j=lst[j]-1)
        {
            if(s[i]==s[j])
            {
                lst[i]=j;break;
            }
        }
        if(lst[i])dp[i]=dp[lst[i]-1]+1,ans+=dp[i];
    }
    cout<<ans<<'\n';
    return 0;
}

如果觉得我的题解写得好,请给我的题解点个赞,谢谢!