【ARC099D】 Eating Symbols Hard 题解

pldzy

2022-07-16 12:10:00

题解

AtC 传送门:F - Eating Symbols Hard

双哈希 + 差分

Solution

1

发现要想确定 [l,r] 对序列 A 的影响,必须得记录下 S 中的所有子串对 A 操作后的状态。进一步地,为了更好处理每一个 [l,r],自然会想到使用差分的方法——记录下 [1,i]\ (i \in [1,\left\vert S \right\vert])

但是上述想法暂时看起来仍有许多缺陷。注意到,我们的目的是确定 A 序列在不同情况下的状态,所以这就引导我们要去使用哈希甚至双哈希进行维护或记录。

为什么哈希能实现上述的差分呢?其实,稍加思考即可发现,即 h(i) 表示按照 S 的子序列 [1,i] 对初始为 0 的序列 A 进行操作后 A 序列的哈希值。那么 h(r)-h(l-1) 实际上就是将前 i-1 个操作造成的影响抵消掉。因为真正对序列 A 造成影响的操作就是对它元素进行加减,而对元素的加减操作显然是可以通过相减抵消的。所以最终只留下第 lr 个操作对序列的影响。

2

回过头来,题目要求统计的合法区间 [l,r] 貌似就是满足 h(r)-h(l-1)=h(n) 的区间。但是,这里忽略掉了一点,就是在 h(r)-h(l-1) 后,这个差值相比 h(n) 实质上会发生一个base 进制上的位移。简单解释一下,统计 [1,n] 区间的哈希值时,下标 p 是从 0 开始的,第一个操作乘上的进制是 base^0,然而,统计 [l,r] 区间的哈希值时,下标 p 经过了前 l-1 次操作,不妨记 P_i 表示第 i 次操作时 p 的下标,那么统计区间 [l,r] 时第一次乘上进制的就是 base^{P_{l}}

所以,为了消除后者比前者多出的从第 0 位到第 l - 1 位的位移,正确的柿子是 \dfrac{h(r)-h(l-1)}{base^{P_{l-1}}} = h(n)

3

知道了如何判断某一子区间是否合法之后,再考虑如何将它放进代码里实现。

看到时间限制和 \left\vert S \right\vert 的大小限制,发现统计时似乎只允许 \mathcal{\text{O}}(n) 的复杂度。这就引导我们想到一个大概的统计方法:枚举 l 遍历 1\left\vert S \right\vert,然后统计合法的区间右端点 r 的数量。进一步地,这样一来 h(n)h(l-1) 已知,根据推出的柿子就发现 h(r) 已经被唯一确定了,即 h(r)=h(n)\times base^{P_{l-1}} + h(l-1),所以一开始计算 h(i) 的时候开一个 map 把它们按照哈希值统计数量即可。

至此,此题被完美解决, 复杂度可过。

4

还要再提一个代码实现的时候需要注意的点。题目中已述,p 有可能为负数。那么此时乘上的进制就成了 base^{-\left\vert p \right\vert},然后简单推一波柿子:base^{-\left\vert p \right\vert}=(base^{-1})^{\left\vert p\right\vert}。而 base^{-1} 是什么?base 在取模某个模数意义下的逆元。因为模数我们一定会设为一个素数,故直接用快速幂算出其逆元即可。

需要双哈希。

Code

#include<bits/stdc++.h>
using namespace std;

#define rep(i, a, b) for(register int i = a; i <= b; ++i)
typedef long long ll;
const int maxn = 250005;
const int mod1 = 1e9 + 7, mod2 = 998244353;
int n, p[maxn], bs = 121;
char c[maxn];
map<pair<int, int>, int> mp;
pair<int, int> hsh[maxn];
ll ans;

inline int pw(int x, int p, int mod){
    int res = 1;
    while(p){
        if(p & 1) 
            res = 1ll * res * x % mod;
        p /= 2, x = 1ll * x * x % mod;
    } return res;
}
inline pair<int, int> gh(int x, int y){
    int a, b; a = b = x;
    if(y < 0) 
        a = pw(a, mod1 - 2, mod1), b = pw(b, mod2 - 2, mod2), y *= -1; 
    return make_pair(pw(a, y, mod1), pw(b, y, mod2));
}

inline pair<int, int> add(pair<int, int> a, pair<int, int> b){
    return make_pair((a.first + b.first) % mod1, (a.second + b.second) % mod2);
}
inline pair<int, int> mns(pair<int, int> a, pair<int, int> b){
    return make_pair((a.first + mod1 - b.first) % mod1, (a.second + mod2 - b.second) % mod2);
}
inline pair<int, int> tim(pair<int, int> a, pair<int, int> b){
    return make_pair((1ll * a.first * b.first) % mod1, (1ll * a.second * b.second) % mod2);
}

int main(){
    scanf("%d", &n); 
    rep(i, 1, n){ 
        char c; cin >> c; 
        p[i] = p[i - 1];
        if(c == '+') hsh[i] = add(hsh[i - 1], gh(bs, p[i]));
        if(c == '-') hsh[i] = mns(hsh[i - 1], gh(bs, p[i]));
        if(c == '<') hsh[i] = hsh[i - 1], p[i] -= 1;
        if(c == '>') hsh[i] = hsh[i - 1], p[i] += 1;
        mp[hsh[i]] += 1;
    }
    rep(i, 1, n)
        ans += mp[add(hsh[i - 1], tim(hsh[n], gh(bs, p[i - 1])))], 
        mp[hsh[i]] -= 1;
    printf("%lld\n", ans);
    return 0;
}

感谢阅读。