树链剖分详解

eros1on

2019-02-21 15:12:28

题解

博客食用更佳

Introduction

当我们想要同时完成

这两种操作时,会怎么做呢?

相信我们对于这两种操作分别都会解决,但是如果一旦放到同一道题里,我们就会束手无策了。

树链剖分便由此而来了。

前置技能:线段树(重要) & 倍增求 LCA(不必须)

Definition & Steps

树剖是通过将一棵有根树分成多个链,然后利用各种数据结构(如线段树等)来维护这棵链,从而间接地维护这棵树。

首先,为了方便我们对树剖的理解,我们需要知道一些很基础很重要的概念:

上一张百度百科的图:

如果我们想知道1的重儿子是谁,那么我们只需递归地求一下它每棵子树的size即可;

同时,我们还可以顺便维护出所有结点的 fadep

递归伪代码如下:

void dfs1(int u, int f) // u 当前结点,f 是 u 的父亲结点
    size[u] = 1
    for each v that connects to u : // u 的子结点
        if(v != f) // v 不是 u 的父亲
            fa[v] = u // 说明 v 是 u 的子结点
            dep[v] = dep[u] + 1 // 深度维护
            dfs(v, u) // 继续递归
            size[u] += size[v] // 将子树的 size 加到这棵树的 size 上
            if(size[wson[u]] < size[v]) wson[u] = v // 更新重儿子

可以手动模拟一下

这样,我们就求出了每个结点的重儿子 wson

特别地,每个叶子结点的重儿子都为 0

每个标红点的结点都是一条重链的链首(top),

而加粗的边则是重链。

比如,1=>4=>9=>13=>14 就是一条重链,而 2=>6=>11 也是另一条重链;

另外我们会注意到,除了根节点以外的所有重链链首都是轻儿子,

例如 top[3]\ =\ 3

回顾一下,还有哪些信息没有维护呢?

dfn,\ pre,\ top

其实,这三个只需要另一个递归函数就足够啦!

void dfs2(int u, int tp) // u 是当前结点,tp 是 u 所在重链的链首
    top[u] = tp
    dfn[u] = ++tot // tot 是时间戳
    pre[tot] = u // pre 是 dfn 的反函数
    if(wson[u]) dfs2(wson[u], tp); // 只要 u 有重儿子,那就可以继续下去
    for each v that connects to u : 
        if(v != fa[u] && v != wson[u]) // v 是 u 的轻儿子
            dfs2(v, v); // 轻儿子是重链的链首

图中边上的数字就是 dfs2 递归地顺序,不明白的可以参考一下。

检测一下你有没有明白:2dfn 是几?7pre 是几?

**好了,树链剖分阶段到此结束。** 将这棵树剖分成了许多链,现在就可以用线段树维护了~ > 写上了一段带修改查询的线段树模板。。。 现在轮到处理问题的阶段了。 为了便于理解,我们就先实现两个最基础的操作吧: 路径点权求和、路径点权修改 大体思路很简单,就是把这条路径分成若干条原来的重链,然后依次实现。 举个栗子:求 $9$ 到 $11$ 路径上的点权之和(还是上面那张图。。) ![图片没了吗?可以到百度百科上看](https://s2.ax1x.com/2019/02/02/k8LSfK.jpg) 为了简单起见,我们暂定每个点的点权是它的编号(不是 $dfn$)。 这个和倍增求 $LCA$ 有点像。 先找到 $dep$ 更深的结点,$11$($dep[11]\ >\ dep[9]$) 将答案 $res$ 加上 $11$ 到 $top[11]$ 的点权和 我们可以用事先维护好的 $dfn$ 来帮助 由于 $top[11]\ =\ 2$,所以我们在求 $2=>11$ 的和 我们发现因为这是一条重链,所以这条路径上的每个点的 $dfn$ 都是连续的! 线段树有用了,注意 $dfn[top[11]] < dfn[11]

res += query(root, dfn[top[11]], dfn[11])

这是,操作变成了求 res\ +\ 9=>11 的路径点权和。

所以我们将 11 变为 fa[top[11]]\ =\ 1

这是,我们发现 19 共链了。

所以像刚才一样,

res += query(root, dfn[1], dfn[9])

即可!

将上面的步骤转化成代码:

void Qsum(int u, int v) {
    int left, right, res = 0;
    while(top[u] != top[v]) { // 只要不共链
        if(dep[top[u]] > dep[top[v]]) swap(u, v); // 每次都要让 top 深一些的往上跳
        res += query(root, dfn[top[v]], dfn[v]);
        v = fa[top[v]];
    } left = dfn[u], right = dfn[v];
    if(left > right) swap(left, right); // 注意大小关系!
    res += query(root, left, right);
    return res;
}

很简单吧!

修改操作类似,

vois Qchange(int u, int v, int d) { // u 到 v 的路径上点权加 d
    int left, right;
    while(top[u] != top[v]) {
        if(dep[top[u]] > dep[top[v]]) swap(u, v);
        change(root, dfn[top[v]], dfn[v], d);
        v = fa[top[v]];
    } left = dfn[u], right = dfn[v];
    if(left > right) swap(left, right);
    change(root, left, right, d);
}

除了这两个基础操作以外,还有一类比较常见的操作:子树。

其实这种操作比刚才两种还简单。

我们考虑以 2 为根的子树:

不难发现,这棵子树里所有结点的 dfn 是连续的。为什么呢?因为 dfn 是通过 dfs 得到的。

因此很显然在这整棵子树中, dfn 最小的那个结点就是子树的根。而最大的则是 dfn + size - 1

(其中 size 是这棵子树的大小)

所以代码只需要一句话,

inline int Tsum(int u) { // 询问以 u 为根的子树
    return query(root, dfn[u], dfn[u] + size[u] - 1);
}

inline void Tchange(int u, int d) { // 修改以 u 为根的子树
    change(root, dfn[u], dfn[u] + size[u] - 1, d);
}

Code

#include <cstdio>
#include <cstdlib>
#include <iostream>
using namespace std;
#define MAXN 100100
int n, m, dep[MAXN], fa[MAXN], wson[MAXN], top[MAXN], dfn[MAXN], tot, size[MAXN], pre[MAXN], w[MAXN];
struct edge { // 存图
    int v;
    edge *next;
} epool[MAXN << 1], *h[MAXN], *ecnt = epool;
struct node { // 存线段树
    int left, right, s, tag;
    node *ls, *rs;
    inline void seta(int x) { tag += x, s += (right - left + 1) * x;}
    inline void upd() { s = ls->s + rs->s;}
    inline void push() {
        if(tag) {
            if(ls) ls->seta(tag);
            if(rs) rs->seta(tag);
            tag = 0;
        }
    }
} pool[MAXN << 3], *root, *cnt = pool;
inline void addedge(int u, int v) { // 加边
    edge *p = ++ecnt, *q = ++ecnt;
    p->v = v, p->next = h[u], h[u] = p;
    q->v = u, q->next = h[v], h[v] = q;
}
inline void dfs1(int u, int f) {
    int v; size[u] = 1;
    for(edge *p = h[u]; p; p = p->next)
        if((v = p->v) != f) {
            fa[v] = u,
            dep[v] = dep[u] + 1;
            dfs1(v, u);
            size[u] += size[v];
            if(size[v] > size[wson[u]]) wson[u] = v;
        }
}
inline void dfs2(int u, int tp) {
    int v; top[u] = tp;
    dfn[u] = ++tot, pre[tot] = u;
    if(wson[u]) dfs2(wson[u], tp);
    for(edge *p = h[u]; p; p = p->next)
        if((v = p->v) != wson[u] && v != fa[u])
            dfs2(v, v);
}
inline void build(node *r, int left, int right) { // 线段树模板
    r->left = left, r->right = right;
    if(left == right) {
        r->s = w[pre[left]];
        return ;
    }
    int mid = (left + right) >> 1;
    node *ls = ++cnt, *rs = ++cnt;
    r->ls = ls, r->rs = rs;
    build(ls, left, mid), build(rs, mid + 1, right);
    r->upd();
}
inline void change(node *r, int left, int right, int d) { // 线段树模板
    if(r->left == left && r->right == right) {
        r->seta(d); return ;
    } r->push();
    if(r->ls->right >= right) change(r->ls, left, right, d);
    else if(r->rs->left <= left) change(r->rs, left, right, d);
    else change(r->ls, left, r->ls->right, d),
         change(r->rs, r->rs->left, right, d);
    r->upd();
}
inline int query(node *r, int left, int right) { // 线段树模板
    r->push();
    if(r->left == left && r->right == right) return r->s;
    if(r->ls->right >= right) return query(r->ls, left, right);
    else if(r->rs->left <= left) return query(r->rs, left, right);
    else return query(r->ls, left, r->ls->right) +
                query(r->rs, r->rs->left, right);
}
inline int Qsum(int u, int v) {
    int left, right, res = 0;
    while(top[u] != top[v]) {
        if(dep[top[u]] > dep[top[v]]) swap(u, v);
        res += query(root, dfn[top[v]], dfn[v]);
        v = fa[top[v]];
    } left = dfn[u], right = dfn[v];
    if(left > right) swap(left, right);
    res += query(root, left, right);
    return res;
}
inline void Qchange(int u, int v, int d) {
    int left, right;
    while(top[u] != top[v]) {
        if(dep[top[u]] > dep[top[v]]) swap(u, v);
        change(root, dfn[top[v]], dfn[v], d);
        v = fa[top[v]];
    } left = dfn[u], right = dfn[v];
    if(left > right) swap(left, right);
    change(root, left, right, d);
}
int main() {
    int op, u, v, d;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) scanf("%d", &w[i]); // 读入点权
    for(int i = 1; i < n; i++) { // 读图
        scanf("%d%d", &u, &v);
        addedge(u, v);
    }
    dep[1] = 1; // 这步不能忘,否则 dfs1 没用!
    dfs1(1, 0), dfs2(1, 1);
    build(root = cnt, 1, n);
    while(m--) { // 操作
        scanf("%d%d%d", &op, &u, &v);
        if(op == 1)
            printf("%d\n", Qsum(u, v));
        else {
            scanf("%d", &d);
            Qchange(u, v, d);
        }
    }
    return 0;
}

最后附上此题代码