P5405 [CTS2019] 氪金手游 题解

· · 题解

好题。

题目给的奇怪性质其实就是 (u_i,v_i) 作为无向边后是一棵树。不妨令 1 为根。

先考虑原题的一个弱化:如果 (u_i,v_i) 作为有向边后是一棵外向树怎么做?

考虑以 x 为根的子树,我们需要 T_x 小于其子树中所有点(不含 x)的 T。假设 W 已经确定,那么枚举 T_x,可以得到 x 这个位置合法的概率为:

\begin{aligned} &\frac{W_x}{\sum W}\sum_{t=0}^{\infty}(\frac{\sum_{i\notin\mathrm{subtree}(x)} W_i}{\sum W})^t\\ =&\frac{W_x}{\sum W}\cdot\frac{1}{1-\frac{\sum_{i\notin\mathrm{subtree}(x)} W_i}{\sum W}}\\ =&\frac{W_x}{\sum W}\cdot\frac{\sum W}{\sum_{i\in\mathrm{subtree}(x)} W_i}\\ =&\frac{W_x}{\sum_{i\in\mathrm{subtree}(x)} W_i}\\ \end{aligned}

这个式子只和以 x 为根的这棵子树有关,因此考虑 dp,设 f_{x,s} 表示已经考虑了以 x 为根的子树,且 \sum\limits_{i\in\mathrm{subtree}(x)} W_i=s 时的方案合法概率。使用树形背包的方式合并子树即可,时间复杂度 \mathcal{O}(n^2)

考虑原问题,其中的内向边让我们很头疼,考虑把它们容斥掉。钦定哪些内向边不合法,则被钦定的内向边相当于反向变成了外向边,而我们不在乎未被钦定的内向边合不合法,所以可以直接删掉这些边。

经过这一处理后原树变成了一个外向树森林,因此同样可以使用上面的式子计算出合法概率。于是同样的进行树形 dp,对一条内向边额外决策是否钦定即可,记得钦定时要带一个 -1 的容斥系数。

#include <algorithm>
#include <iostream>
#include <vector>

using namespace std;
using LL = long long;

const int kN = 1001;
const int kM = 998244353;

int n, a[kN][3], iv[kN * 3], f[kN][kN * 3], g[kN * 3], s[kN];
vector<pair<int, bool>> e[kN];

int inv(int x, int m) { return x == 1 ? 1 : m - 1LL * inv(m % x, x) * m / x; }
void dfs(int x, int _f) {
  f[x][0] = 1;
  for (auto [y, t] : e[x]) {
    if (y == _f) {
      continue;
    }
    dfs(y, x);
    for (int i = 0; i <= s[x] * 3; ++i) {
      for (int j = 0; j <= s[y] * 3; ++j) {
        int p = 1LL * f[x][i] * f[y][j] % kM;
        if (t) {
          g[i] = (g[i] + p) % kM, g[i + j] = (g[i + j] - p + kM) % kM;
        } else {
          g[i + j] = (g[i + j] + p) % kM;
        }
      }
    }
    s[x] += s[y];
    copy_n(g, s[x] * 3 + 1, f[x]), fill_n(g, s[x] * 3 + 1, 0);
  }
  for (int i = s[x] * 3; i >= 0; --i) {
    for (int j = 1; j <= 3; ++j) {
      f[x][i + j] = (f[x][i + j] + 1LL * f[x][i] * a[x][j - 1] % kM * j * iv[i + j] % kM) % kM;
    }
    f[x][i] = 0;
  }
  ++s[x];
}

int main() {
  ios::sync_with_stdio(0), cin.tie(0);
  cin >> n;
  for (int i = 1; i <= n; ++i) {
    cin >> a[i][0] >> a[i][1] >> a[i][2];
    LL _iv = inv(a[i][0] + a[i][1] + a[i][2], kM);
    a[i][0] = a[i][0] * _iv % kM;
    a[i][1] = a[i][1] * _iv % kM;
    a[i][2] = a[i][2] * _iv % kM;
  }
  iv[1] = 1;
  for (int i = 2; i <= n * 3; ++i) {
    iv[i] = 1LL * (kM - kM / i) * iv[kM % i] % kM;
  }
  for (int i = 1, x, y; i < n; ++i) {
    cin >> x >> y;
    e[x].emplace_back(y, 0);
    e[y].emplace_back(x, 1);
  }
  dfs(1, 0);
  int ans = 0;
  for (int i = 1; i <= n * 3; ++i) {
    ans = (ans + f[1][i]) % kM;
  }
  cout << ans;
  return 0;
}