【学习笔记】线段树合并

· · 算法·理论

Foreword

通常我们会遇到树上或图上问题,问题需要对每个节点开线段树维护且需要合并的时候,就会用到线段树合并,就比如将子节点线段树的信息合并到父节点上。

Main

也就是说我们会将两颗维护的区间相同的线段树对应节点的信息合并,前置知识:动态开点线段树。

我们记需要合并的线段树为 S_a,S_b,在遍历到某节点,若两棵树中一对应节点为空,返回另一棵中存在的节点,这利用了动态开点的思想,在遍历到叶子结点时,根据题目需要进行操作。

通常的,假设我们将 S_b 合并到 S_a 上,这无疑会破坏 S_a 所有包含当前节点区间的节点信息,这样的情况也就对应着 S_a,S_b 当前两棵树上都没有节点,所以可以选择新建节点,这又利用了动态开点线段树的特点。

同时需要注意合并线段树的同时根据题目需要向上更新信息。

主要代码:

int merge (int rtA, int rtB, int l, int r) {
    if (!rtA) return rtB;
    if (!rtB) return rtA;
    // 返回存在的节点
    if (l == r) {
        // 根据题目需要进行更新操作
        return rtA;
    }
    int mid = (l + r) >> 1;
    ls(rtA) = merge (ls(rtA), ls(rtB), l, mid); // 合并左儿子
    rs(rtA) = merge (rs(rtA), rs(rtB), mid + 1, r); // 合并右儿子
    update (rtA); // 向上更新信息
    return rtA;
}

这是将一颗线段树合并到另一棵上的实现,如果说新建节点的话:

int merge (int rtA, int rtB, int l, int r) {
    if (!rtA) return rtB;
    if (!rtB) return rtA;
    // 返回存在的节点
    if (l == r) {
        // 根据题目需要进行更新操作
        return rtA;
    }
    int mid = (l + r) >> 1, newNode = ++ tot; // 新建节点
    ls(newNode) = merge (ls(rtA), ls(rtB), l, mid); // 合并左儿子
    rs(newNode) = merge (rs(rtA), rs(rtB), mid + 1, r); // 合并右儿子
    update (newNode); // 向上更新信息
    return rtA;
}

总复杂度 O(n \log n)

awa。

Ex.1 P4556

模板题。

题意就是类似于有 m 个区间加操作,只不过是在树上,给两点之间路径进行区间加操作。

利用差分转化成单点修改,同时将路径分成 (x,\operatorname{lca}(x,y)),(\operatorname{lca}(x,y),y) 两条,这里要求救济粮种类,所以考虑使用权值线段树。

具体为,我们遍历整棵树时,将子节点的线段树合并到父节点,每次更新答案就从父节点查询到答案,每次单点修改 x,y,\operatorname{lca}(x,y),fa_{\operatorname{lca}(x,y)}

#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MAXN = 1e5 + 10;
int n, m, head[MAXN], idx, Ans[MAXN];
struct node {
    int v, nxt;
} edge[MAXN << 1];
inline void Addedge (int u, int v) { edge[++ idx].v = v, edge[idx].nxt = head[u], head[u] = idx; }
namespace Segmentree {
    int ind = 0, p[MAXN][25], tr[MAXN], depth[MAXN];
    struct SMT {
        int lson, rson, nod, Sum;
    } Seg[MAXN << 6];
    #define ls(rt) Seg[rt].lson
    #define rs(rt) Seg[rt].rson
    inline void update (int rt) {
        if (Seg[ls(rt)].Sum < Seg[rs(rt)].Sum) 
            Seg[rt].Sum = Seg[rs(rt)].Sum, Seg[rt].nod = Seg[rs(rt)].nod;
        else
            Seg[rt].Sum = Seg[ls(rt)].Sum, Seg[rt].nod = Seg[ls(rt)].nod;
    }
    int merge (int rtA, int rtB, int l, int r) {
        if (!rtA) return rtB;
        if (!rtB) return rtA;
        if (l == r) {
            Seg[rtA].Sum += Seg[rtB].Sum;
            return rtA;
        }
        int mid = (l + r) >> 1;
        ls(rtA) = merge (ls(rtA), ls(rtB), l, mid);
        rs(rtA) = merge (rs(rtA), rs(rtB), mid + 1, r);
        update (rtA);
        return rtA;
    }
    int Add (int pos, int rt, int l, int r, int k) {
        if (!rt) rt = ++ ind;
        if (l == r) {
            Seg[rt].Sum += k;
            Seg[rt].nod = pos;
            return rt;
        }
        int mid = (l + r) >> 1;
        if (pos <= mid)
            ls(rt) = Add (pos, ls(rt), l, mid, k);
        else
            rs(rt) = Add (pos, rs(rt), mid + 1, r, k);
        update (rt);
        return rt;
    }
    void dfs (int u) {
        for (int i = 0; i <= 20; i ++)
            p[u][i + 1] = p[p[u][i]][i];
        for (int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].v;
            if (v == p[u][0]) continue;
            depth[v] = depth[u] + 1;
            p[v][0] = u;
            dfs (v);
        }
    }
    int LCA (int u, int v) {
        if (depth[u] < depth[v]) swap (u, v);
        int del = depth[u] - depth[v];
        for (int i = 20; ~i; i --) {
            if ((del >> i) & 1)
                u = p[u][i];
        }
        if (u == v) return u;
        for (int i = 20; ~i; i --) {
            if (p[u][i] != p[v][i])
                u = p[u][i], v = p[v][i];
        }
        return p[u][0];
    }
} using namespace Segmentree; 

void dfsCalc (int u) {
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int v = edge[i].v;
        if (v == p[u][0]) continue;
        dfsCalc (v);
        tr[u] = merge (tr[u], tr[v], 1, 1e5);
    }
    Ans[u] = Seg[tr[u]].nod;
    if (!Seg[tr[u]].Sum) Ans[u] = 0;
}

signed main() {
    cin.tie (0) -> sync_with_stdio (0);
    cout.tie (0) -> sync_with_stdio (0);
    cin >> n >> m;
    memset (head, -1, sizeof head);
    for (int i = 1; i < n; i ++) {
        int u, v; cin >> u >> v;
        Addedge (u, v), Addedge (v, u);
    }
    dfs (1);
    while (m --) {
        int x, y, z, t; cin >> x >> y >> z;
        t = LCA (x, y);
        tr[x] = Add (z, tr[x], 1, 1e5, 1);
        tr[y] = Add (z, tr[y], 1, 1e5, 1);
        tr[t] = Add (z, tr[t], 1, 1e5, -1);
        tr[p[t][0]] = Add (z, tr[p[t][0]], 1, 1e5, -1);
    }
    dfsCalc (1);
    for (int i = 1; i <= n; i ++)
        cout << Ans[i] << endl;
    return 0;
}

Ex.2 P3714

看到这类树上路径问题先想到点分治,但是要维护长度 [l,r] 内的路径,考虑线段树。

记当前分治中心为 p,则我们给路径分类:

  1. 路径两端在不同子树内,且两端到分治中心的颜色不同。
  2. 路径两端在不同子树内,且两端到分治中心的颜色相同。
  3. 路径一端为 p,一端在子树内。
  4. 两端在同一子树内。

后面两种情况是正常递归就可以解决,前面两种比较棘手。

我们可以开两颗线段树存储所有情况,按照边的颜色排序,分别存已经被遍历的颜色的路径最大权值和当前遍历的颜色的路径最大权值,在查询当前遍历颜色需要减去当前节点的值。

#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int MAXN = 2e5 + 33, inf = 1e18;
int n = 0, m = 0, L = 0, R = 0, Col[MAXN], head[MAXN], idx = 0, tot = 0, Ans = -inf;
struct Tree {
    int v, c, nxt;
    Tree (int v = 0, int c = 0, int nxt = 0):v(v), c(c), nxt(nxt){};
} edge[MAXN << 1];
inline void Addedge (int u, int v, int c) { edge[++ idx] = Tree (v, c, head[u]), head[u] = idx; }

namespace Segmentree {
    struct SMT {
        int lson, rson, Val, maxVal;
        void Clear() { lson = rson = 0, maxVal = Val = -inf; }
    } Seg[MAXN * 80];
    #define lson(rt) Seg[rt].lson
    #define rson(rt) Seg[rt].rson
    inline void pushup (int rt) { Seg[rt].maxVal = max (Seg[lson(rt)].maxVal, Seg[rson(rt)].maxVal); }
    int Merge (int rtA, int rtB, int l, int r) {
        if (!rtA) return rtB;
        if (!rtB) return rtA;
        if (l == r) {
            Seg[rtA].maxVal = Seg[rtA].Val = max (Seg[rtA].Val, Seg[rtB].Val);
            Seg[rtB].Clear();
            return rtA;
        }
        int mid = (l + r) >> 1;
        lson(rtA) = Merge (lson(rtA), lson(rtB), l, mid);
        rson(rtA) = Merge (rson(rtA), rson(rtB), mid + 1, r);
        Seg[rtB].Clear(), pushup (rtA);
        return rtA;
    }
    void Update (int l, int r, int p, int rt, int k) {
        if (l == r) {
            Seg[rt].maxVal = Seg[rt].Val = max (Seg[rt].Val, k);
            return;
        }
        int mid = (l + r) >> 1;
        if (p <= mid) {
            if (!lson(rt)) lson(rt) = ++ tot;
            Update (l, mid, p, lson(rt), k);
        } else {
            if (!rson(rt)) rson(rt) = ++ tot;
            Update (mid + 1, r, p, rson(rt), k);
        }
        pushup (rt);
    }
    int Query (int ql, int qr, int l, int r, int rt) {
        if (ql <= l && qr >= r)
            return Seg[rt].maxVal;
        int mid = (l + r) >> 1, tmpVal = -inf;
        if (ql <= mid && lson(rt)) 
            tmpVal = max (tmpVal, Query (ql, qr, l, mid, lson(rt)));
        if (qr > mid && rson(rt))
            tmpVal = max (tmpVal, Query (ql, qr, mid + 1, r, rson(rt)));
        return tmpVal;
    }
    void ClearTree (int rt) {
        if (lson(rt)) ClearTree (lson(rt));
        if (rson(rt)) ClearTree (rson(rt));
        Seg[rt].Clear();
    }
} using namespace Segmentree;

namespace DivTree {
    int rtA = 0, rtB = 0, rt = 0; bool tag[MAXN];
    int getSiz (int u, int fa) {
        int Siz = 1;
        for (int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].v;
            if (v == fa || tag[v]) continue;
            Siz += getSiz (v, u);
        }
        return Siz;
    }
    int findrt (int u, int fa, int totSiz) {
        int mxSiz = 0, sumSiz = 1;
        for (int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].v;
            if (v == fa || tag[v]) continue;
            int Sz = findrt (v, u, totSiz);
            sumSiz += Sz;
            mxSiz = max (mxSiz, Sz);
        }
        mxSiz = max (mxSiz, totSiz - sumSiz);
        if (mxSiz <= totSiz / 2) rt = u;
        return sumSiz;
    }
    void getDis (int rt, int u, int fa, int Val, int Len, int col, int delta) {
        if (Len > R) return;
        Ans = max (Ans, Query (max (1LL, L - Len), max (1LL, R - Len), 1, R, rtA) + Val - delta);
        Ans = max (Ans, Query (max (1LL, L - Len), max (1LL, R - Len), 1, R, rtB) + Val);
        if (Len >= L) Ans = max (Ans, Val);
        Update (1, R, Len, rt, Val);
        for (int i = head[u]; ~i; i = edge[i].nxt) {
            int v = edge[i].v, c = edge[i].c;
            if (v == fa || tag[v]) continue;
            if (c == col) {
                getDis (rt, v, u, Val, Len + 1, col, delta);
            } else {
                getDis (rt, v, u, Val + Col[c], Len + 1, c, delta);
            }
        }
    }
    void Solve (int u) {
        findrt (u, 0, getSiz (u, 0));
        tag[rt] = true;
        priority_queue < pair<int,int> > q;
        for (int i = head[rt]; ~i; i = edge[i].nxt) {
            int v = edge[i].v, c = edge[i].c;
            if (tag[v]) continue;
            q.push (make_pair (c, v)); 
        }
        int lstNode = 0;
        rtB = ++ tot;
        while (!q.empty()) {
            int c = q.top().first, v = q.top().second; q.pop();
            if (c != lstNode)
                Merge (rtB, rtA, 1, R), rtA = ++ tot;
            int _rt = ++ tot;
            getDis (_rt, v, rt, Col[c], 1, c, Col[c]);
            Merge (rtA, _rt, 1, R);
            lstNode = c;
        }
        tot = 0, ClearTree (rtA), ClearTree (rtB);
        for (int i = head[rt]; ~i; i = edge[i].nxt) {
            int v = edge[i].v;
            if (tag[v]) continue;
            Solve (v);
        }
    }
} using namespace DivTree;

signed main() {
    scanf ("%lld %lld %lld %lld", &n, &m, &L, &R);
    memset (head, -1, sizeof head);
    for (int i = 0; i < MAXN * 10; i ++) 
        Seg[i].maxVal = Seg[i].Val = -inf;
    for (int i = 1; i <= m; i ++)
        scanf ("%lld", &Col[i]);
    for (int i = 1, u, v, c; i < n; i ++) {
        scanf ("%lld %lld %lld", &u, &v, &c);
        Addedge (u, v, c), Addedge (v, u, c);
    }
    Solve (1);  
    printf ("%lld\n", Ans);
    return 0;
}

Summary

推荐题单