题解 P5287 【[HNOI2019]JOJO】

· · 题解

Description

两种操作

每次操作后,输出当前字符串的next数组只和。

操作数\leq 10^5c \leq 10^4

Solution

2操作显然可以用离线建树解决。

考虑如果一段前缀匹配一段后缀,那么除了第一段的字符,其他段的二元组(x,c)需要满足完全相等。这本质上可以将(x,c)视为一种特殊的字符进行匹配。

先考虑在当前串的末尾新加入一个二元组(x,c)怎么计算答案,显然是不断跳next,如果当前前缀后接的字符为c,那么可以增加一段首项为当前前缀长度,公差为1的等差数列的贡献(要与之前的贡献取max,也就是只能覆盖之前没覆盖的位置)。

而如果在next链上有恰好等于(x,c)的二元组,那么当前串结尾的next指向最靠后的二元组,否则指向0。因为如果存在一个二元组(x',c)满足x'>x,那么由于下次添加的字符一定与当前字符不同,所以一定无法匹配。

但是由于kmp的复杂度是均摊的,所以直接把上述算法套进树上dfs是行不通的。

考虑一个叫kmp自动机的东西,它的本质是把kmpnext的过程预处理,由于本题字符集大小很大,用主席树维护。

f_{i,j,k}表示在串的s_{i-1}位置添加一个字符(j,k)next所到达的位置,同理设g_{i,j,k}表示增加的答案。在dfs的时候,修改f_{i,x,c}的值,并将g_{i,x,1..c}设置为首项为当前串长度,公差为1的等差数列。dfs下一层前把f_{i+1}的状态由f_{next[i]+1}继承过来。

由于每次操作都是把数列设为一个公差为1的等差数列,所以只用维护首项所覆盖的位置,剩下的用高斯求和单独算即可。

#include <bits/stdc++.h>
using namespace std;

inline int gi()
{
    char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    int sum = 0;
    while ('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
    return sum;
}

inline char gc()
{
    char c = getchar();
    while (c < 'a' || c > 'z') c = getchar();
    return c;
}

typedef long long ll;
const int maxn = 100005, M = 1e4 + 7, mod = 998244353;

int n;
int val[maxn], pos[maxn], ans[maxn], a[maxn], b[maxn], top;
vector<int> to[maxn];

int rt[maxn][26], mx[maxn][26], tot;
struct seg
{
    int l, r, lch, rch, sum, lzy, nxt;
} t[maxn * 60];

#define mid ((l + r) >> 1)

inline void new_node(int &s) {t[++tot] = t[s]; s = tot;}
inline void add(int s, int v, int len) {t[s].sum = (ll)v * len % mod; t[s].lzy = v;}
inline void push_down(int s, int l, int r)
{
    if (!t[s].lzy) return ;
    new_node(t[s].lch); add(t[s].lch, t[s].lzy, mid - l + 1);
    new_node(t[s].rch); add(t[s].rch, t[s].lzy, r - mid);
    t[s].lzy = 0;
}

void modify(int &s, int l, int r, int x, int v, int p)
{
    new_node(s);
    if (r < x) return add(s, v, r - l + 1);
    if (l == r) return t[s].nxt = p, add(s, v, 1);
    push_down(s, l, r);
    modify(t[s].lch, l, mid, x, v, p);
    if (x > mid) modify(t[s].rch, mid + 1, r, x, v, p);
    t[s].sum = (t[t[s].lch].sum + t[t[s].rch].sum) % mod;
}

void query(int &s, int l, int r, int x, int &ans, int &nxt)
{
    if (r < x) return ans = (ans + t[s].sum) % mod, void();
    if (l == r) return ans = (ans + t[s].sum) % mod, nxt = t[s].nxt, void();
    push_down(s, l, r);
    query(t[s].lch, l, mid, x, ans, nxt);
    if (x > mid) query(t[s].rch, mid + 1, r, x, ans, nxt);
}

inline int getsum(int x) {return ((ll)x * (x + 1) >> 1) % mod;}

void dfs(int u)
{
    ++top;
    int x = val[u] / M, y = val[u] % M, nxt = 0;
    a[top] = val[u]; b[top] = b[top - 1] + y;
    if (top == 1) ans[u] = getsum(y - 1);
    else {
        ans[u] = (ans[u] + getsum(min(mx[top][x], y))) % mod;
        query(rt[top][x], 1, M, y, ans[u], nxt);
        if (!nxt && a[1] / M == x && b[1] < y) nxt = 1, ans[u] = (ans[u] + (ll)b[1] * max(0, y - mx[top][x])) % mod;
    }
    mx[top][x] = max(mx[top][x], y);
    modify(rt[top][x], 1, M, y, b[top - 1], top);
    for (int v : to[u]) {
        memcpy(mx[top + 1], mx[nxt + 1], sizeof(mx[top + 1]));
        memcpy(rt[top + 1], rt[nxt + 1], sizeof(rt[top + 1]));
        ans[v] = ans[u]; dfs(v);
    }
    --top;
}

int main()
{
    n = gi();
    for (int op, x, i = 1; i <= n; ++i) {
        op = gi(); x = gi();
        if (op == 1) val[++tot] = (gc() - 'a') * M + x, pos[i] = tot, to[pos[i - 1]].push_back(pos[i]);
        else pos[i] = pos[x];
    }

    for (int i : to[0]) {
        tot = 0;
        memset(rt[1], 0, sizeof(rt[1]));
        memset(mx[1], 0, sizeof(mx[1]));
        dfs(i);
    }

    for (int i = 1; i <= n; ++i) printf("%d\n", ans[pos[i]]);

    return 0;
}