「学习笔记」可持久化线段树

· · 算法·理论

权值线段树 & 可持久化线段树

Part 1:权值线段树

在了解可持久化线段树前,必须了解权值线段树

权值线段树是一种十分简单的线段树,对于对应区间 [l,r] 的节点 i,该节点保存给定数组中数值在区间 [l,r] 的数据个数,不妨将其记为 s_{i}

利用权值线段树,我们可以根据每一个节点的大小,不断向下查找,就可以找到任意排名的数了。

Part 2:可持久化线段树

可持久化线段树可以解决这样的问题:

给定 n 个整数构成的序列 a,将对于指定的闭区间 [l,r] 查询其区间内的第 k 小值。

数据范围:1 \le n,m \le 2×10^5 0 \le a_i \le 10^9

这个题目与前文提到的权值线段树的功能有几个区别:

一个个解决这些问题。

问题 1:实现可持久化

我们需要存储 n 个版本,如果单纯地开 n 个线段树的话,肯定会爆空间。

我们注意到,根据线段树的特点,我们每次新添加一个节点,从树根到叶子节点,只会经过 \log n 个节点,其余节点的信息并不会受到影响。

因此,我们在插入一个新的数值时,不需要全部重开,只需要新增 \log n 个节点,其余的指回上一个版本的对应节点即可。

下图^{[1]} 为加入一个新数值的示意图,其中红色的为新的节点和边。

具体来说,就是从前一个版本的根节点出发,并且每次新开一个节点作为新的节点。

首先,该新节点先继承前一个版本的对应节点(这里的对应代表记录相同的值域范围,例如上图中的节点 3 和节点 11)的左右孩子,然后根据新加入的数值的大小,递归进行该操作,并修改改当前节点的左/右孩子即可。

然后,对于查找,我们可以记录下每个版本的根节点编号,这样每次从对应编号的节点往下查找即可。

这样,我们就顺利解决了可持久化以及空间的问题。

Tips:显然,实现可持久化线段树需要动态开点。在本题的数据范围下,大约需要 2 \times 10^6 个节点^{[2]}

问题 2:在指定区间内进行查找

前置知识:前缀和思想。

很显然,我们可以在任意 [1,r] 的区间进行查找,那如果改为 [l,r] 呢?

目前,我们可以求解 [1,l] 以及 [1,r]

显然,我们也可以处理 [1,l-1]

于是我们可以利用前缀和的思想将其联系在一起。

具体地说,我们现在要查找 [l,r] 内的第 k 小,我们可以设版本 r(即为插入第 r 个数后的版本)与版本 l-1 对应起来。

不妨设这两个对应的节点分别是 ij,这两个节点所记录的值域范围都为 [a,b],中间值为 mid(对应的含义同上)。

根据定义,节点 i 的左孩子的大小即为区间 [1,l-1] 中,比 mid 要小的数据个数,记为 x

同理,节点 j 的左孩子的大小即为区间 [1,r] 中,比 mid 要小的数据个数,记为 y

那么到这里就很清晰了,区间 [l,r] 中,比 mid 要小的数据个数就是 y-x 啦!

之后的操作就和普通的权值线段树的操作相同了:若 k \le y-x,就走左边,否则就走右边。

至此,我们就基本实现了可持久化线段树啦!

其他的细节

Part 3:例题

模板

首先是板子。

三道题,同一份代码(数据范围还是要注意下的)!

Code #1~3

#include<bits/stdc++.h>
using namespace std;

const int N=2e5+5;
//          去重后的长度 
int n,m,idx,len;
//一个是原数组,另一个是去重 + 排序后的 
int a[N],b[N];
int ls[N<<5],rs[N<<5],sum[N<<5],ver[N<<5];

int find(int x){ //离散找排名 
    return lower_bound(b+1,b+1+len,x)-b;
}

//建树 
int build(int l,int r){
    int now=++idx; //新开一个节点 
    if(l==r) return now;
    int mid=(l+r)>>1;
    ls[now]=build(l,mid);
    rs[now]=build(mid+1,r);
    return now;
}

//插入一个新数值 
int insert(int l,int r,int pre,int k){
    int now=++idx;
    //先继承对应的节点的左右孩子,更新大小 
    ls[now]=ls[pre],rs[now]=rs[pre],sum[now]=sum[pre]+1;
    if(l==r) return now;
    int mid=(l+r)>>1;
    if(k<=mid){
        ls[now]=insert(l,mid,ls[now],k);
    }
    else{
        rs[now]=insert(mid+1,r,rs[now],k);
    }
    return now;
}

//     当前节点   之前版本的对应节点 
int query(int now,int pre,int l,int r,int k){
    if(l==r) return l;
    int mid=(l+r)>>1;
    //前缀和思想 
    int x=sum[ls[now]]-sum[ls[pre]];
    if(k<=x){
        return query(ls[now],ls[pre],l,mid,k);
    }
    else{
        return query(rs[now],rs[pre],mid+1,r,k-x);
    }
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        b[i]=a[i];
    }
    sort(b+1,b+1+n);
    len=unique(b+1,b+1+n)-b-1;
    ver[0]=build(1,len);

    for(int i=1;i<=n;i++){
        ver[i]=insert(1,len,ver[i-1],find(a[i]));
    }

    int l,r,k;
    for(int i=1;i<=m;i++){
        cin>>l>>r>>k;
        cout<<b[query(ver[r],ver[l-1],1,len,k)]<<"\n";
    }

    return 0;   
} 

P3567 [POI2014] KUR-Couriers

题目大意:

给一个长度为 n 的正整数序列 a。共有 m 组询问,每次询问一个区间 [l,r],是否存在一个数在 [l,r] 中出现的次数严格大于一半。如果存在,输出这个数,否则输出 0

数据范围:1 \le n,m \le 5\times 10^51 \le a_i \le n

本题也可以用可持久化线段树解决。

修改一下查询函数即可通过此题。

具体来说,我们由原先的求排名变为求出现次数。

我们可以利用前缀和的思想(与原先无区别),求出区间内大于中间值和小于中间值的数据个数。然后分别判断是否大于总数的一半。如果是就继续在该子树中查找。否则直接返回 0,表示不存在。

由于本题数据的特殊性,不离散化也是可以的。

Code #4

#include<bits/stdc++.h>
using namespace std;

const int N=5e5+5;
//          去重后的长度 
int n,m,idx;
int a[N];
int ls[N<<5],rs[N<<5],sum[N<<5],ver[N<<5];

//建树 
int build(int l,int r){
    int now=++idx; //新开一个节点 
    if(l>=r) return now;
    int mid=(l+r)>>1;
    ls[now]=build(l,mid);
    rs[now]=build(mid+1,r);
    return now;
}

//插入一个新数值 
int insert(int l,int r,int pre,int k){
    int now=++idx;
    //先继承对应的节点的左右孩子,更新大小 
    ls[now]=ls[pre],rs[now]=rs[pre],sum[now]=sum[pre]+1;
    if(l>=r) return now;
    int mid=(l+r)>>1;
    if(k<=mid){
        ls[now]=insert(l,mid,ls[now],k);
    }
    else{
        rs[now]=insert(mid+1,r,rs[now],k);
    }
    return now;
}

//     当前节点   之前版本的对应节点 
int query(int now,int pre,int l,int r,int k){
    if(l>=r) return l;
    int mid=(l+r)>>1;
    //前缀和思想 
    int x=sum[ls[now]]-sum[ls[pre]],y=sum[rs[now]]-sum[rs[pre]];
    //就这里改了一下 
    if(x>k){ //如果总数大于一半,就继续找 
        return query(ls[now],ls[pre],l,mid,k);
    }
    else if(y>k){
        return query(rs[now],rs[pre],mid+1,r,k);
    }
    return 0;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    ver[0]=build(1,n);

    for(int i=1;i<=n;i++){
        ver[i]=insert(1,n,ver[i-1],a[i]);
    }

    int l,r;
    for(int i=1;i<=m;i++){
        cin>>l>>r;
        cout<<query(ver[r],ver[l-1],1,n,(r-l+1)/2)<<"\n";
    }

    return 0;   
} 

P4587 [FJOI2016] 神秘数

题目大意不复述了。

我们可以先假设集合中的数都是排好序的,然后按顺序添加数字。

注意:以下所表示的区间均为值域区间。

设目前可以最大且满足题意的区间为 [1,p],也就是说区间 [1,p] 中的数都可以表示。设目前已经添加到集合中最大的数为 q,因为我们是按照数的大小添加数的,所以所有小于等于 q 的数均已被添加。

初始时,p=q=0

注意到询问的答案就是在所有数都添加完后,p+1 的值。

那么我们每次可以添加的数是多少呢?

首先,可添加的最小值肯定是 q+1,因为所有小于等于 q 的数均已被添加。

然后确定一下右边界。由于我们上面说过,在目前答案为 p+1 的情况下,加入大于 p+1 的数对答案没有贡献。

综上,我们可以加入的数的区间就是 [q+1,p+1]

之后,由于我们把 [q+1,p+1] 内的所有数加入集合,那么 q \leftarrow p+1

设我们加入的数的总和为 x,那么答案区间就扩大为 [1,p+x],也就是使 p \leftarrow p+x

众所周知,线段树可以实现区间求和,同样地,我们可以将其可持久化,实现求区间和的操作。

Tips:

Code #5

#include<bits/stdc++.h>
#define i_ak_ioi true 
using namespace std;

const int N=1e5+5,M=1e9;
int n,m,idx,len;
int a[N],b[N];
int ls[N<<5],rs[N<<5],ver[N<<5];
long long sum[N<<5];

//插入一个新数值 
int insert(int l,int r,int pre,long long k){
    int now=++idx;
    //先继承对应的节点的左右孩子,更新数值 
    ls[now]=ls[pre],rs[now]=rs[pre],sum[now]=sum[pre]+k;
    if(l==r) return now;
    int mid=(l+r)>>1;
    if(k<=mid){
        ls[now]=insert(l,mid,ls[now],k);
    }
    else{
        rs[now]=insert(mid+1,r,rs[now],k);
    }
    return now;
}

//     当前节点   之前版本的对应节点  当前区间 目标区间 
long long query(int now,int pre,int L,int R,int l,int r){
    //与普通线段树操作大致相同(只不过记录的是权值) 
    if(l<=L&&R<=r) return sum[now]-sum[pre];
    long long ans=0;
    int mid=(L+R)>>1;
    if(l<=mid) ans+=query(ls[now],ls[pre],L,mid,l,r);
    if(r>=mid+1) ans+=query(rs[now],rs[pre],mid+1,R,l,r);
    return ans;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);

    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=1;i<=n;i++){
        ver[i]=insert(1,M,ver[i-1],a[i]);
    }

    cin>>m;

    int l,r;
    for(int i=1;i<=m;i++){
        cin>>l>>r;
        //一个是当前最大值,一个是能表示的最大值 
        long long maxn=0,ans=0;
        while(i_ak_ioi == true){
            long long res=query(ver[r],ver[l-1],1,M,maxn+1,ans+1);
            if(res){
                maxn=ans+1;
                ans+=res;
            }
            else break;
        }
        cout<<ans+1<<"\n";
    }

    return 0;   
} 

Ending

本文的思路来源于 OI-Wiki,但是进行了细化与改良,把不是很清楚的地方用自己的语言解释出来。

欢迎指出文中的不当之处!

注:

$[2]$:具体计算过程参见 [**OI-Wiki**](https://oi-wiki.org/ds/persistent-seg/)。 $\text{Upd on 2025/1/24}$ :增加了一道例题。 --- $$\text{A Fascinating Ending.}$$