题解 AT2370 【[AGC013D] Piling Up】

· · 题解

在博客园食用更佳:https://www.cnblogs.com/PinkRabbit/p/AGC013.html。

我们假设初始时红砖的数量为 x_0,则蓝砖的数量应为 N - x_0

同时,我们记第 i 次操作后的红砖的数量为 x_i

观察第 i + 1 次操作时,可能执行的四种情况:

  1. 操作 RR:先拿出一块红砖,然后塞入一红一蓝两砖,再拿出一块红砖。
    此操作仅能在 x_i > 0 时被执行,并且 x_{i + 1} = x_i - 1\boldsymbol{x} 减少 \boldsymbol{1})。
  2. 操作 RB:先拿出一块红砖,然后塞入一红一蓝两砖,再拿出一块蓝砖。
    此操作仅能在 x_i > 0 时被执行,并且 x_{i + 1} = x_i不变)。
  3. 操作 BR:先拿出一块蓝砖,然后塞入一红一蓝两砖,再拿出一块红砖。
    此操作仅能在 x_i < N 时被执行,并且 x_{i + 1} = x_i不变)。
  4. 操作 BB:先拿出一块蓝砖,然后塞入一红一蓝两砖,再拿出一块蓝砖。
    此操作仅能在 x_i < N 时被执行,并且 x_{i + 1} = x_i + 1\boldsymbol{x} 增加 \boldsymbol{1})。

对于一种操作序列,根据上面的 x_i 关于 x_0 的表达式以及不等式,我们可以确定出 x_0 对应的区间(或为空集,此时序列不合法)。

我们考虑直接进行 DP:令 \mathrm{f}[i][j] 表示 x_i = j 时前 i 次操作的方案数,初始化 \mathrm{f}[0][\ast] = 1,答案就是 \mathrm{f}[N][\ast] 之和,转移显然。

错误!这样计算会使得一个合法的操作序列被统计多次,也就是上面的 x_0 可能的区间内至少包含两个整数的情况。

我们不妨考虑必须在 x_0 可能到达的最小值处统计一个操作序列。这意味着 x_0 哪怕再减小 1,操作序列都将变得不合法。

体现在 DP 过程中,即是存在一个时刻 j 值曾经到达过 0,或者为 1 时执行了 RB 操作。

所以我们更换状态定义,令 \mathrm{f}[i][j] 表示 x_i = j 时前 i 次操作的方案数,限制是 j 从未到达过最小值。

同时定义 \mathrm{g}[i][j] 表示类似含义,但限制是 j 到达过至少一次最小值。

转移仍然显然,答案即为 \mathrm{g}[N][\ast] 之和。时间复杂度为 \mathcal O (N M)

#include <cstdio>

typedef long long LL;
const int Mod = 1000000007;
const int MN = 3005;

int N, M;
int f[2][MN], g[2][MN], Ans;

int main() {
    scanf("%d%d", &N, &M);
    int o = 0;
    for (int i = 1; i <= N; ++i) f[o][i] = 1;
    g[o][0] = 1;
    for (int i = 1; i <= M; ++i) {
        o ^= 1;
        for (int j = 0; j <= N; ++j) f[o][j] = g[o][j] = 0;
        for (int j = 0; j <= N; ++j) {
            if (j) {
                f[o][j] = (f[o][j] + f[o ^ 1][j - 1]) % Mod;
                g[o][j] = ((LL)g[o][j] + g[o ^ 1][j - 1] + g[o ^ 1][j]) % Mod;
                if (j == 1) g[o][j] = (g[o][j] + f[o ^ 1][j]) % Mod;
                else f[o][j] = (f[o][j] + f[o ^ 1][j]) % Mod;
            }
            if (j < N) {
                g[o][j] = ((LL)g[o][j] + g[o ^ 1][j + 1] + g[o ^ 1][j]) % Mod;
                if (!j) g[o][j] = (g[o][j] + f[o ^ 1][j + 1]) % Mod;
                else f[o][j] = ((LL)f[o][j] + f[o ^ 1][j + 1] + f[o ^ 1][j]) % Mod;
            }
        }
    }
    for (int j = 0; j <= N; ++j) Ans = (Ans + g[o][j]) % Mod;
    printf("%d\n", Ans);
    return 0;
}