CF653F Paper task

· · 题解

CF653F Paper task

来一发 SA 的题解((

首先,我们知道实际上此题就是要求本质不同的合法括号子串的数量。

既然如此,不妨先不考虑本质不同这个条件,关注合法括号子串这个问题。

我们令 ( 的权值是 1,) 的权值是 -1。令 d_j = \sum\limits_{i = 1}^j v_i,其中 v_i 表示第 i 个字符的权值。

那我们先列出合法括号串的条件:

  1. 在串内的任意一个位置,在它前面 ( 的数量大于等于 ) 的数量。
  2. 串首的字符是 (
  3. 整个串 ( 的数量等于 ) 的数量。

把这些条件转化可以得到:

  1. \forall i, d_i \ge d_1 - 1
  2. d_n = d_1 - 1

于是只需要在 d 上建立 ST表,并对每个位置进行二分就能得到一个 R_i 表示使得子串 s_{i\cdots r} 为合法子串的最远的 r。那怎么统计有多少个以 i 为左边界的合法括号子串呢?我们只需要求出 i \cdots R_i 里有多少个 d_j = d_i - 1 的点即可。这个问题可以给每个 d_i 的值开一个存下标的 vector 进行二分求得。

然后我们要考虑的就是本质不同这个限制,但实际上也很简单。考虑最简单的本质不同子串的做法,求出 SAheight 两个数组之后求 sa_i sa_{i - 1} 有多少最长公共合法括号前缀即可。这个可以用前面的方法轻松求得。最后用总数减去就行了。

下面是我的 Code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll maxn = 5e5 + 5;
const ll maxv = 5e5;
// again
ll n;
char s[maxn];
ll sa[maxn << 1], rk[maxn << 1], cnt[maxn];
ll tmp[maxn], x[maxn], y[maxn];
void getsa(){
    memset(cnt, 0, sizeof(cnt));
    for(int i = 1;i <= n;i++) rk[i] = s[i], cnt[rk[i]]++;
    for(int i = 1;i <= maxv;i++) cnt[i] += cnt[i - 1];
    for(int i = n;i >= 1;i--) sa[cnt[rk[i]]--] = i;
    for(int w = 1;w <= n;w <<= 1){
        ll p = 0;
        for(int i = n;i >= n - w + 1;i--) tmp[++p] = i;
        for(int i = 1;i <= n;i++){
            if(sa[i] > w) tmp[++p] = sa[i] - w;
            x[i] = rk[i], y[i] = rk[i + w];
        }
        memset(cnt, 0, sizeof(cnt));
        for(int i = 1;i <= n;i++) cnt[rk[tmp[i]]]++;
        for(int i = 1;i <= maxv;i++) cnt[i] += cnt[i - 1];
        for(int i = n;i >= 1;i--) sa[cnt[rk[tmp[i]]]--] = tmp[i];
        ll num = 1; rk[sa[1]] = 1;
        for(int i = 2;i <= n;i++){
            rk[sa[i]] = num;
            if(x[sa[i]] != x[sa[i - 1]] || y[sa[i]] !=
            y[sa[i - 1]]){
                rk[sa[i]] = ++num;
            }
        }
    }
    return ;
}
ll hei[maxn];
void gethei(){
    ll lth = 0;
    for(int i = 1;i <= n;i++){
        if(lth) lth--;
        if(rk[i] > 1){
            while(i + lth <= n && sa[rk[i] - 1] + lth <= n 
            && s[i + lth] == s[sa[rk[i] - 1] + lth]) lth++;
        }
        hei[rk[i]] = lth;
    }
    return ;
}
ll d[maxn];
ll st[maxn][21], lg[maxn];
void getst(){
    for(int i = 1;i <= n;i++) st[i][0] = d[i];
    for(int i = 1;(1 << i) <= n;i++){
        for(int j = 1;j <= n - (1 << i) + 1;j++){
            st[j][i] = min(st[j][i - 1], st[j + (1 << (i - 1))][i - 1]);
        }
    }
    for(int i = 1;i <= n;i++){
        lg[i] = lg[i - 1] + ((1 << (lg[i - 1] + 1)) == i);
    }
    return ;
}
ll rmq(ll l, ll r){
    ll len = lg[r - l + 1];
    return min(st[l][len], st[r - (1 << len) + 1][len]);
}
vector<ll> v[maxn << 1];
ll que(ll l, ll r, ll val){
    val += maxv;
    ll L = lower_bound(v[val].begin(), v[val].end(), l) - 
    v[val].begin();
    ll R = upper_bound(v[val].begin(), v[val].end(), r) -
    v[val].begin() - 1;
    if(L > R || L == v[val].size()) return 0;
    return R - L + 1;
}
ll sid[maxn];
ll cntans;
int main(){
    scanf("%lld", &n);
    scanf("%s", s + 1);
    getsa(); gethei();
    for(int i = 1;i <= n;i++){
        if(s[i] == '('){
            d[i] = d[i - 1] + 1;
        } else {
            d[i] = d[i - 1] - 1;
            v[d[i] + maxv].push_back(i);
        }
    }
    getst();
    for(int i = 1;i <= n;i++){
        ll l = i, r = n, ans = 0;
        while(l <= r){
            ll mid = (l + r) >> 1;
            if(rmq(i, mid) < d[i] - 1){
                r = mid - 1;
            } else {
                ans = mid;
                l = mid + 1;
            }
        }
        sid[i] = ans;
        if(s[i] == '(') cntans += que(i, sid[i], d[i] -  1);
    }
    for(int i = 2;i <= n;i++){
        if(s[sa[i]] == ')') break;
        ll R = min(sa[i] + hei[i] - 1, sid[sa[i]]);
        cntans -= que(sa[i], R, d[sa[i]] - 1);
    }
    cout << cntans << endl;
    return 0;
}

上述算法的复杂度为 \mathcal O(n \log n),可以通过此题。