题解 P5666 【树的重心】

xht

2019-11-30 23:08:35

题解

考虑每个点作为重心的次数。

首先拿出一个重心来当作根 rt

对于点 x \ne rt,如果 x 为割掉某条边后的重心,那么这条边一定不在 x 的子树内。

设割掉一条边后,另外一棵树的大小为 S

g_x = \max_{y \in \text{son}(x)} \{s_y\}

由于 x 为重心,所以要满足:

2 \times (n - S - s_x) \le n - S 2 \times g_x \le n - S

即:

n - 2s_x \le S \le n - 2g_x

即我们要找到对于一个 x,有多少条可以割掉的边满足:

  1. 边不在 x 的子树内。

如果只有第一个条件,那么我们可以进行一次 dfs,同时拿一个树状数组动态维护割掉当前点和其父亲之间的边后每一个 S 的值有多少个,那么对于每一个点的询问实质上就是一个区间求和。

然后我们要去掉不满足第二个条件,即在 x 的子树内的边,可以再拿一个树状数组按照 dfs 序动态维护经过的所有点的每一个 s 的值有多少个,那么对于一个点,在其子树内可以割掉的边就是回溯与进入时区间求和后的差。

最后还有一个问题是,当 x = rt,我们如何统计 x 为重心的次数?

x 的儿子中子树最大的节点为 u,次大的节点为 v

若割掉的边在 u 的子树中,则需要满足:

2 \times s_v \le n - S

即:

S \le n - 2s_v

否则,需要满足:

2 \times s_u \le n - S

即:

S \le n - 2s_u

可以在 dfs 的时候直接维护。

那么总时间复杂度为 \mathcal O(n \log n)

const int N = 3e5 + 7;
int n, rt, s[N], g[N], u, v, z[N];
vi e[N];
ll ans, c1[N], c2[N];

inline void add(ll *c, int x, int k) {
    ++x;
    while (x <= n + 1) c[x] += k, x += x & -x;
}

inline ll ask(ll *c, int x) {
    ++x;
    ll k = 0;
    while (x) k += c[x], x -= x & -x;
    return k;
}

void dfs1(int x, int f) {
    s[x] = 1, g[x] = 0;
    bool fg = 1;
    for (ui i = 0; i < e[x].size(); i++) {
        int y = e[x][i];
        if (y == f) continue;
        dfs1(y, x);
        s[x] += s[y];
        g[x] = max(g[x], s[y]);
        if (s[y] > (n >> 1)) fg = 0;
    }
    if (n - s[x] > (n >> 1)) fg = 0;
    if (fg) rt = x;
}

void dfs2(int x, int f) {
    add(c1, s[f], -1);
    add(c1, n - s[x], 1);
    if (x ^ rt) {
        ans += x * ask(c1, n - 2 * g[x]);
        ans -= x * ask(c1, n - 2 * s[x] - 1);
        ans += x * ask(c2, n - 2 * g[x]);
        ans -= x * ask(c2, n - 2 * s[x] - 1);
        if (!z[x] && z[f]) z[x] = 1;
        ans += rt * (s[x] <= n - 2 * s[z[x]?v:u]);
    }
    add(c2, s[x], 1);
    for (ui i = 0; i < e[x].size(); i++) {
        int y = e[x][i];
        if (y == f) continue;
        dfs2(y, x);
    }
    add(c1, s[f], 1);
    add(c1, n - s[x], -1);
    if (x ^ rt) {
        ans -= x * ask(c2, n - 2 * g[x]);
        ans += x * ask(c2, n - 2 * s[x] - 1);
    }
}

inline void solve() {
    rd(n);
    for (int i = 1; i <= n; i++) e[i].clear();
    for (int i = 1, x, y; i < n; i++) rd(x), rd(y), e[x].pb(y), e[y].pb(x);
    ans = 0;
    dfs1(1, 0);
    dfs1(rt, 0);
    u = v = 0;
    for (ui i = 0; i < e[rt].size(); i++) {
        int x = e[rt][i];
        if (s[x] > s[v]) v = x;
        if (s[v] > s[u]) swap(u, v);
    }
    for (int i = 1; i <= n + 1; i++) c1[i] = c2[i] = 0;
    for (int i = 0; i <= n; i++) add(c1, s[i], 1), z[i] = 0;
    z[u] = 1;
    dfs2(rt, 0);
    print(ans);
}

int main() {
    int T;
    rd(T);
    while (T--) solve();
    return 0;
}