P10800 卡牌 题解

· · 题解

首先考虑降维。

注意到若一个卡牌能满足条件,则严格大于其的卡牌都能满足条件,即满足条件的卡牌是连续的,那么可以设 f_{i, j, k} 表示当攻击为 i、防御为 j、速度为 k 时满足要求的卡牌数量。

当添加一张属性为 (a, b, c, d) 的卡牌时:

例如当 n = 3 时,添加一张 (1, 1, 2, 2) 的卡牌,此时 f 的值是这样的:

0 0 0 | 0 0 0 | 0 1 1
0 1 1 | 0 1 1 | 1 3 3
0 1 1 | 0 1 1 | 1 3 3

(其中每一行表示 i,每一列表示 j,每一层表示 k

发现我们做的操作其实是:

下文中我们记 a_i 为每一行取到的最小值,b_i 为每一列取到的最小值,c_i 为每一层取到的最小值。只需要从后往前推一遍即可 O(n) 求出这些了。

然后考虑如何统计答案。

实际上,我们可以对于每一层分别求和,同时维护每一列的和。

我们用四个数据结构维护以下四个信息:

那么维护过程大概是这样的:首先枚举每一层,处理被取消置零的行和不被 c_i 覆盖的行,同时用数据结构维护以上四个信息,得到每一列的和后考虑 b_jc_i 覆盖了哪些列,然后去掉被置零的列即可。

我们用 \log 数据结构维护以上四个信息,那么总复杂度是 O(m + n \log n) 的。

具体维护过程见代码。

#include <algorithm>
#include <iostream>
using namespace std;
using u32 = unsigned;
const int N = 5e5 + 5;
int n, m, a[N], b[N], c[N], x[N], y[N], z[N], f[N], g[N], h[N];
u32 ans;
// 线段树维护区间带权加、区间求和
struct segment_tree {
    u32 s[N], t[N << 2], lzy[N << 2];
    void pushup(int rt) { t[rt] = t[rt << 1] + t[rt << 1 | 1]; }
    void push(int rt, int l, int r, u32 k) { t[rt] += (s[r] - s[l - 1]) * k, lzy[rt] += k; }
    void pushdown(int rt, int l, int r) {
        if (!lzy[rt]) return;
        int mid = (l + r) >> 1;
        push(rt << 1, l, mid, lzy[rt]);
        push(rt << 1 | 1, mid + 1, r, lzy[rt]);
        lzy[rt] = 0;
    }
    void update(int rt, int l, int r, int x, int y, u32 k) {
        if (x <= l && r <= y) return push(rt, l, r, k);
        pushdown(rt, l, r);
        int mid = (l + r) >> 1;
        if (x <= mid) update(rt << 1, l, mid, x, y, k);
        if (y >= mid + 1) update(rt << 1 | 1, mid + 1, r, x, y, k);
        pushup(rt);
    }
    u32 query(int rt, int l, int r, int x, int y) {
        if (x <= l && r <= y) return t[rt];
        pushdown(rt, l, r);
        int mid = (l + r) >> 1;
        u32 ret = 0;
        if (x <= mid) ret += query(rt << 1, l, mid, x, y);
        if (y >= mid + 1) ret += query(rt << 1 | 1, mid + 1, r, x, y);
        return ret;
    }
    void update(int l, int r, u32 k) { update(1, 1, n, l, r, k); }
    u32 query(int l, int r) { return query(1, 1, n, l, r); }
} t1;
// 树状数组维护区间加、区间求和
struct fenwick_tree {
    u32 t1[N], t2[N];
    void update(int x, u32 k) {
        for (int i = x; i <= n; i += i & -i) t1[i] += k, t2[i] += k * x;
    }
    u32 query(int x) {
        u32 ret = 0;
        for (int i = x; i >= 1; i -= i & -i) ret += (x + 1) * t1[i] - t2[i];
        return ret;
    }
    void update(int l, int r, u32 k) { update(l, k), update(r + 1, -k); }
    u32 query(int l, int r) { return query(r) - query(l - 1); }
} t2, t3, t4;
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) a[i] = b[i] = c[i] = n;
    for (int i = 1, ai, bi, ci, di; i <= m; i++) {
        cin >> ai >> bi >> ci >> di, di = n - di;
        a[ai] = min(a[ai], di); // 每一行最小值 ai
        b[bi] = min(b[bi], di); // 每一列最小值 bi
        c[ci] = min(c[ci], di); // 每一层最小值 ci
        x[ai] = max(x[ai], bi); // 每一行置零的列数 xi
        y[ci] = max(y[ci], ai); // 每一层置零的行数 yi
        z[ci] = max(z[ci], bi); // 每一层置零的列数 zi
    }
    for (int i = n - 1; i >= 1; i--) {
        a[i] = min(a[i], a[i + 1]), b[i] = min(b[i], b[i + 1]), c[i] = min(c[i], c[i + 1]);
        x[i] = max(x[i], x[i + 1]), y[i] = max(y[i], y[i + 1]), z[i] = max(z[i], z[i + 1]);
    }
    y[0] = n, g[0] = 1; // 处理一些边界
    for (int i = 1, j = 1, k = 1, l = 1; i <= n; i++) {
        while (j <= n && b[j] <= a[i]) ++j; // fi 为第一个 bj > ai 的列
        while (k <= n && a[k] <= c[i]) ++k; // gi 为第一个 aj > ci 的行
        while (l <= n && b[l] <= c[i]) ++l; // hi 为第一个 bj > ci 的列
        f[i] = j, g[i] = k, h[i] = l;
    }
    for (int i = 1; i <= n; i++) t1.s[i] = t1.s[i - 1] + b[i];
    for (int i = 1; i <= n; i++) {
        // 处理被取消置零的行
        for (int j = y[i - 1]; j > y[i]; j--) {
            // 可以被 bj 覆盖的列区间 (xj, fj)
            // 可以被 ci 覆盖的列区间 (xj, n]
            int l = x[j] + 1, r = max(l - 1, f[j] - 1);
            t1.update(l, r, 1);                      // 更新 t1,注意使用的数据结构
            t2.update(r + 1, n, a[j]);               // 更新 t2
            if (j < g[i - 1]) t3.update(l, n, a[j]); // 若没有被 ci 覆盖则更新 t3
            else t4.update(l, n, 1);                 // 否则更新 t4
        }
        // 处理不被 ci 覆盖的行
        for (int j = g[i - 1]; j < g[i]; j++) {
            int l = x[j] + 1;
            if (j > y[i]) t3.update(l, n, a[j]), t4.update(l, n, -1); // 若没有被置零则更新
        }
        // 可以被 bj 覆盖的列区间 (zi, hi)
        int l = z[i] + 1, r = max(l - 1, h[i] - 1);
        // 对于 bj < ci 的列,用 bj 覆盖肯定更优,否则用 ci 覆盖
        ans += t1.query(l, r) + t2.query(l, r) + t3.query(r + 1, n) + t4.query(r + 1, n) * c[i];
    }
    cout << ans;
}