P3412 仓鼠找sugar II 题解

· · 题解

P3412 仓鼠找sugar II 题解

大水题一个

题目大意

给定你一个树,设 f_{u, v} 表示在树上随机游走的情况下从 u 走到 v 的期望步数,求 \displaystyle \frac{\sum_{i = 1}^n \sum_{j = 1}^n f_{i, j}}{n^2}

题解

不难想到 dp,不过 1e5 的范围差点让我怀疑我 O(n) 的 dp。先设一下状态,设 f_u 表示 u 子树内的所有点全都走到点 u 的期望步数。答案就是以每个点为根时根的 f 值的和。

考虑怎么转移。

似乎不好直接转,于是想想我们转移时什么东西卡住了我们。假设现在 u 子树内的所有点都走到了 u,那么我们现在想要让这些点再从 u 结点走向它的父亲结点,这个期望步数不好直接求。

于是我们再设 g_u 表示从 u 结点走到它的父亲结点的期望步数。先来考虑它的转移。deg_u 表示 u 结点的度,即与它相连的边数,son_u 表示 u 结点的儿子构成的集合。

g_u &= \frac{1}{deg_u} + \sum_{v \in son_u} \frac{1 + g_v + g_u}{deg_u}\\ deg_u \times g_u &= 1 + \sum_{v \in son_u} (1 + g_v + g_u)\\ &= 1 + (deg_u - 1) + (deg_u - 1) \times g_u + \sum_{v \in son_u} g_v\\ &= deg_u + (deg_u - 1) \times g_u + \sum_{v \in son_u} g_v\\ g_u &= deg_u + \sum_{v \in son_u} g_v \end{aligned} $$\begin{aligned} f_u &= \sum_{v \in son_u} f_v + size_v \times g_v \end{aligned}$$ 这个非常好理解。 于是可以打 $n^2$ 了。 换一下根,就 $O(n)$ 了。 设 $h(x)$ 为以 $x$ 为根时 $x$ 的 $f$ 值,那么有: $$\begin{aligned} h_u &= f_u + (h_{fa} - f_u - size_u \times g_u) + (n - size_u) \times (g_1 - g_u) \end{aligned}$$ 最终答案为 $\displaystyle \frac{\sum_i^n h_i}{n^2}$。 然后就没了。 ## 代码 ```cpp #include <bits/stdc++.h> #define int long long using namespace std; const int M = 100005; const int mod = 998244353; int n, f[M], g[M], siz[M], out[M], inv, ans, h[M]; int from[M << 1], to[M << 1], head[M], nex[M << 1], tot; inline void add_edge(int u, int v) { from[++ tot] = u; to[tot] = v; nex[tot] = head[u]; head[u] = tot; } void dfs1(int u, int fa) { g[u] = out[u]; siz[u] = 1; for(int i = head[u]; i; i = nex[i]) { int v = to[i]; if(v == fa) continue; dfs1(v, u); siz[u] += siz[v]; g[u] = (g[u] + g[v]) % mod; f[u] = (f[u] + f[v] + siz[v] * g[v] % mod) % mod; } } void dfs2(int u, int fa) { h[u] = (f[u] + (h[fa] - f[u] + mod - siz[u] * g[u] % mod + mod) % mod + (n - siz[u]) * ((g[1] - g[u] + mod) % mod) % mod) % mod; ans = (ans + h[u]) % mod; for(int i = head[u]; i; i = nex[i]) { int v = to[i]; if(v == fa) continue; dfs2(to[i], u); } } inline int quick_pow(int base, int ci, int pp) { int res = 1; while(ci) { if(ci & 1) res = res * base % pp; base = base * base % pp; ci >>= 1; } return res; } signed main() { ios::sync_with_stdio(0); cin.tie(0), cout.tie(0); cin >> n; inv = quick_pow(n, mod - 2, mod); for(int i = 1; i < n; ++ i) { int u, v; cin >> u >> v; add_edge(u, v); add_edge(v, u); ++ out[u]; ++ out[v]; } dfs1(1, 0); ans = f[1]; h[1] = f[1]; for(int i = head[1]; i; i = nex[i]) dfs2(to[i], 1); ans = ans * inv % mod * inv % mod; cout << ans; } ```