题解 P5327 【[ZJOI2019]语言】
前言
先来扯些真实的废话。
这八成是考场上最可做的题,原因有以下:
众所周知,数据范围越不正常的题越毒瘤,本题的数据范围就很正常。
往年省选包括今年 round 1 最简单的题中往往都有线段树,本题感性上就离不开线段树。
九条在出完 2、3、1 的开题顺序之后,喜欢出 2、1、3 的开题顺序。
遗憾的是,由于 round 1 的爆炸,考场上心态有点崩,智商骤降为 0,导致没做出此题。
可能我永远属于入场思维受阻、出场茅塞顿开的那种人吧。
解题思路
不难想到各种 听说会被九条定向卡。
实际上有一种巧妙的
首先容易发现以下性质:
考虑这个生成树是怎么来的?可以进一步得出:
如果路径
s \to t 包含点u ,s,\,t 则是u 的两个极远点。
举个例子,假设有三条路径分别是
锁定
于是,我们需要事先思考子问题:给定树上若干的点,如何求最小生成树大小?
这应该是个经典问题。可能要用到一点建虚树的思想。
为了方便,我们硬点存在极远点
我们给这些点按照
上面那段可能有点绕,强烈建议画个图感受一下!
暴力加点似乎不容易被利用,此时最擅长分治的线段树终于重磅登场了。
仍然是以
白白套一个线段树不是多
想象每个结点都有一棵线段树。注意到有一条路径
慢着……线段树打标记是可以,但每个结点要把自己的信息继承给父亲,这怎么做?
动态开点 + 线段树合并!
至此,[ZJOI2019]语言 的大致做法讲完了。
代码实现
刚才求的
#include <cmath>
#include <cstdio>
#include <vector>
#include <algorithm>
const int N = 200005, V = 6400005, L = 18;
int n, m, tms, o[N], ft[N], dep[N], dfn[N], st[L][N];
std::vector<int> to[N], del[N];
long long ans;
inline int getLca(int u, int v);
struct SegmentTree {
int tot, c[V], f[V], s[V], t[V], ls[V], rs[V], rt[N];
inline void pushUp(int u) {
f[u] = f[ls[u]] + f[rs[u]] - dep[getLca(t[ls[u]], s[rs[u]])];
s[u] = s[ls[u]] ? s[ls[u]] : s[rs[u]];
t[u] = t[rs[u]] ? t[rs[u]] : t[ls[u]];
}
inline int query(int u) { return f[u] - dep[getLca(s[u], t[u])]; }
void modify(int &u, int l, int r, int p, int x) {
if (!u) { u = ++tot; }
if (l == r) {
c[u] += x; f[u] = c[u] ? dep[p] : 0; s[u] = t[u] = c[u] ? p : 0;
return;
}
int mid = l + r >> 1;
if (dfn[p] <= mid) { modify(ls[u], l, mid, p, x); }
else { modify(rs[u], mid + 1, r, p, x); }
pushUp(u);
}
void merge(int &u, int v, int l, int r) {
if (!u || !v) { u |= v; return; }
if (l == r) {
c[u] += c[v]; f[u] |= f[v]; s[u] |= s[v]; t[u] |= t[v];
return;
}
int mid = l + r >> 1;
merge(ls[u], ls[v], l, mid); merge(rs[u], rs[v], mid + 1, r); pushUp(u);
}
} smt;
void build() {
for (int i = 1; i <= tms; i++) { o[i] = log(i) / log(2) + 1e-7; }
for (int i = 1; i <= o[tms]; i++) {
for (int j = 1, u, v; j + (1 << i) - 1 <= tms; j++) {
u = st[i - 1][j]; v = st[i - 1][j + (1 << i - 1)];
st[i][j] = dep[u] < dep[v] ? u : v;
}
}
}
inline int getLca(int u, int v) {
if (!u || !v) { return 0; } u = dfn[u]; v = dfn[v];
if (u > v) { std::swap(u, v); }
int d = o[v - u + 1]; u = st[d][u]; v = st[d][v - (1 << d) + 1];
return dep[u] < dep[v] ? u : v;
}
void dfs(int u, int fa) {
ft[u] = fa; dep[u] = dep[fa] + 1; st[0][++tms] = u; dfn[u] = tms;
for (auto v : to[u]) { if (v != fa) { dfs(v, u); st[0][++tms] = u; } }
}
void solve(int u) {
for (auto v : to[u]) { if (v != ft[u]) { solve(v); } }
for (auto v : del[u]) { smt.modify(smt.rt[u], 1, tms, v, -1); }
ans += smt.query(smt.rt[u]); smt.merge(smt.rt[ft[u]], smt.rt[u], 1, tms);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 2, u, v; i <= n; i++) {
scanf("%d%d", &u, &v);
to[u].push_back(v); to[v].push_back(u);
}
dfs(1, 0); build();
for (int u, v, lca; m; m--) {
scanf("%d%d", &u, &v); lca = getLca(u, v);
smt.modify(smt.rt[u], 1, tms, u, 1); smt.modify(smt.rt[u], 1, tms, v, 1);
smt.modify(smt.rt[v], 1, tms, u, 1); smt.modify(smt.rt[v], 1, tms, v, 1);
del[lca].push_back(u); del[ft[lca]].push_back(u);
del[lca].push_back(v); del[ft[lca]].push_back(v);
}
solve(1); printf("%lld\n", ans >> 1);
return 0;
}
尾注
看完后您还可真别说这题简单,考场上这么多算法的选择,鬼知道偏偏碰上了这一种呢?