数据结构笔记【三】:主席树

· · 算法·理论

零、前言

网上和 OI-wiki 上对于主席树的介绍大多较简洁,对我这种入门水平选手非常不友好。
于是写一篇十分拙劣的面向普及选手的主席树。

一、简述

主席树,目前最普遍的说法是将它解释为“可持久化权值线段树“,因其发明人缩写为“HJT”而得名。
分析它的名字:

如图所示,假设我们要修改序列中的 3 号位,也就是图中树的 7 号节点,那么 1267 节点就会发生改变。
所以分裂它们,得到 16171819 号节点,然后在这四个节点上进行修改即可。
按照原线段树的左右子树关系,这个线段树就长这个样子。

可以发现,图中其实包含了两个线段树,只不过他们共用了一些节点,所以空间就大大减少。

2. 修改节点

具体实现这个步骤也差不多:

  1. 建立先前版本根节点的副本。
  2. 先以先前版本根节点的左右儿子作为副本的左右儿子。
  3. 找要修改的节点在左子树还是在右子树。
  4. 如果在左子树,就递归,顺便把副本的左儿子改成先前版本根节点左儿子的副本。
  5. 在右子树同理。
  6. 剩下的那个儿子就还是原先的儿子,因为剩下的那一棵子树没有变化,与先前版本共用。
    代码:
#define ls(p) st[p].l
#define rs(p) st[p].r

int update(int l,int r,int pos,int rt){
    int create=++cnt;
    ls(create)=ls(rt),rs(create)=rs(rt);
    if(l==r){/*进行相应修改*/;return create;}
    int mid=(l+r)>>1;
    if(pos<=mid) ls(create)=update(l,mid,pos,ls(create));
    else rs(create)=update(mid+1,r,pos,rs(create));
    /*pushup*/
    return create;
}

这个代码似乎有些难懂,下面还有一种写法:

#define ls(p) st[p].l
#define rs(p) st[p].r
void update(int l,int r,int pre,int &p){
    p=++cnt;
    ls(p)=ls(pre),rs(p)=rs(pre);
    if(l==r) {/*进行相应修改*/;return ;}
    int mid=(l+r)>>1;
    if(pos<=mid) update(l,mid,ls(pre),ls(p));
    else update(mid+1,r,rs(pre),rs(p));
    /*pushup*/
}

注意这里 p 是引用类型(不知道怎么用?建议重学 C++),以及第一个写法中的 rt 或第二个写法的 pre

最后主函数中我们定义 root_i 为第 i 个历史版本的线段树的根节点的编号。
显而易见,调用的时候应该调用 update(1,n,root[h],root[i]),意思是第 i 次操作时对编号为 h 的版本进行修改。
最好跟着代码手推一遍。

这个代码只是一个框架,具体的修改和 pushup 我并没有写。
其中 ls(p) 指节点 p 的左儿子,rs(p) 指节点 p 的右儿子,cnt 为总节点数。
关于这个写法,我们下一节再说。

3. 树结构

不加优化的主席树就是整体分裂,而加优化的主席树就是局部分裂。
而根据这种分裂方法,我们还能得出几个性质:

  1. 除叶子节点外,分裂之后的每个节点都有两个子节点。
  2. 每次修改后,根节点都会分裂。
  3. 每次修改会有 \log_2n 个节点分裂。
    根据第一个和第二个性质,可以得到,给定某个根节点的版本,就可以确定一棵线段树。
    再重申一遍,这棵线段树的某些节点可能与曾经的线段树共用,这代表这些节点在本次修改中没有被改变。
    根据第三个性质,可以得到可持久化线段树的空间复杂度为 O(n+m\log n)m 为操作次数。
    这样就绝对不会 MLE 了!

最后,再详细说一下新版本线段树的左右儿子关系。
当某个节点被复制,也就是分裂的时候,判断被修改的位置是在左子树还是右子树,我们管被修改的位置所处的子树叫做子树 A,另一棵子树叫做子树 B,那么这个分裂出来的新节点可以直接连向老版本线段树的子树 B 的根,这样子树 B 就共用了。
然后把子树 A 的根分裂,连边之后向下走一步。
这一步还是挺简单的。

四、建树

与经典的线段树原理一样,只不过我们要使用动态开点线段树。
此时,节点 a 的左右儿子不再是 2a2a+1 ,而是 ls_ars_a,分别表示 a 的左儿子和右儿子。
这样做的目的是,新版本的线段树因为有分裂操作,所以很难保证满足左右儿子为 2a2a+1 的条件。
具体就是每次递归时新建一个节点,然后把这个新结点作为左儿子或右儿子,这可以用返回值的方法来实现。

int build(int l,int r){
    int rt=++cnt;
    if(l==r){
        //此处进行基本的单点赋初始值
        return rt;
    }
    int mid=(l+r)>>1;
    ls(rt)=build(l,mid);
    rs(rt)=build(mid+1,r);
    //此处进行 pushup 操作
    return rt;
}

可以发现,此时传参就不用传节点编号了。

五、常规操作

可持久化数据结构最经典的操作就是在某个历史版本中访问并修改。
我们以可持久化数组为例,要求:

码风比较清奇,主要是主函数中调用的部分是重点。
此题中需要注意的是 N 要开到 3\times10^7 才能过。

六、区间 k 小值

到这里,我们再开始说真正的“主席树”,也就是可持久化权值线段树。
可持久化线段树 2 要求我们完成的操作是:

假设目前i=8,k=5a 数组的前 8 项为:

3,1,2,3,3,6,1,2

那么此时的 cnt 即为:

2,2,3,0,0,1

建立线段树:

图中节点上的数字代表这个节点的 d 值。
初始的时候我们在根节点。
发现 k 小于等于左儿子的 d 值,所以答案一定在左子树内。(这一步应该很容易理解吧)
向下走一步,得到下图:

这次我们发现 k 大于左儿子的 d,那就走到右子树。
重点!此时 k 需要减去左儿子的 d,也就是说现在的 k 改为 3
为什么呢?
现在就不得不提 query 的定义了。
query(l,r,k,p) 可以被定义为,可重集 S=\set{t\in\set{a_1,a_2,\cdots,a_i}|l\le t\le r} 的第 k 小值,[l,r]p 的管辖范围。
我非常喜欢数学语言,因为它比较严谨直观,且不容易发生歧义。
简便起见,我们记之为 Q(l,r,k)
\mu=\lfloor\dfrac{l+r}{2}\rfloor。也就是代码中的 mid
再设可重集 S_1=\set{t\in\set{a_1,a_2,\cdots,a_i}|l\le t\le \mu},S_2=\set{t\in\set{a_1, a_2,\cdots, a_i}|\mu+1\le t\le r}, 如果 k\le|S_1|,则 Q(l,r,k)=Q(l,\mu,k),否则 Q(l,r,k)=Q(\mu+1,r,k-|S_1|)
这样就变得易于理解多了!
因为 S_1 就是 S 中前 |S_1| 小的所有数,所以当 k>\mu 时,S_2 的第 k-|S_1| 小的值就是 S 中第 k 小的值啦!
综上所述,如果走右子树的话,k 需要减去左子树的 d 值。
于是向下走:

走完发现还在右子树(k>2), 于是 k 减去 2,得到 1

最后走到了值 3 所对应的叶子节点。
所以答案是 3

一个例子跑完,对权值线段树的理解加深了不少。
代码:

int query(int l,int r,int k,int p){
    int mid=(l+r)>>1,num=d[ls(p)];
    if(l==r) return l;
    if(k<=num) return query(l,mid,k,ls(p));
    else return query(mid+1,r,k-num,rs(p));
}

2. 主席树

现在考虑任意区间的 k 小值。
还是以 \{3,1,2,3,3,6,1,2\} 为例。
i3 的时候,cnt=\{1,1,1,0,0,0\}
i8 的时候,cnt=\{2,2,3,0,0,1\}
也就是说,i 在从 48 的过程中,cnt 的变化量就是 \Delta cnt=\{+1,+1,+2,+0,+0,+1\}
再回头一看,a_4a_8 中,不就是恰好有 111223,和 16 吗?
所以就是说,通过 rl-1 两个版本的 cnt 的相减,就可以得到下标区间 [l,r] 所对应的那个 \Delta cnt,也就可以得到从 a_la_r 每个数字出现几次了!
这是非常好的。现在只需要考虑用可持久化权值线段树来记录 cnt 的每一个历史版本就好了!
具体地,对于每次 i\gets i+1 的时候,建立一个新的版本,然后在新的版本上对 cnt_{a_i} 进行加一就可以了。

数学地说,我们定义节点 \lambda,\rho,为第 l-1 个版本以及第 r 个版本的线段树中,管辖范围为值域 [L,R] 的两个节点。
再定义 cnt_i^{(h)} 为第 h 个版本中的 cnt_ih 阶导?)
与之前对 d 的定义几乎不变。

那么设 $\xi(l,r,L,R)=d_{\rho}-d_{\lambda}=\displaystyle\sum_{i=l}^r cnt_i^{(r)}-cnt_{i}^{(l-1)}$。 和之前的方法一样,我们定义 $Q(l,r,k,p)$ 为:**可重集** $\color{red}{S=\set{t\in\set{a_l,a_{l+1},\cdots,a_{r-1},a_r}|L\le t\le R}}$ 的第 $k$ 小值,$[L,R]$ 是 $p$ 的管辖范围。 ~~这很合理。~~ 设 $\mu=\lfloor\dfrac{L+R}{2}\rfloor$。也就是代码中的 `mid`。 再设可重集 $S_1=\set{t\in\set{a_l,\cdots,a_r}|L\le t\le \mu},S_2=\set{t\in\set{a_l,\cdots, a_r}|\mu+1\le t\le R}$, 如果 $k\le|S_1|$,则 $Q(l,r,k,p)=Q(l,r,k,ls(p))$,否则 $Q(l,r,k,p)=Q(l,r,k-|S_1|,rs(p))$。 这样就比所谓“感性理解“看着好多了。 现在只需要算出 $|S_1|$ 就可以了! 它不就是 $\displaystyle\sum_{i=L}^\mu cnt_i^{(r)}-cnt_i^{(l-1)}$ 吗? 不就是 $\xi(l,r,L,\mu)$ 吗? 完美! 我们如复制粘贴般的就写好了 `query` : ```cpp int query(int l,int r,int s,int t,int k){ int mid=(s+t)>>1,xi=d[ls(r)]-d[ls(l)]; if(s==t) return s; if(k<=xi) return query(ls(l),ls(r),s,mid,k); else return query(rs(l),rs(r),mid+1,t,k-xi); } ``` 最后需要考虑的是求出 $d$ 数组。 这简单多了,根据可持久化线段树的原理,在 `query` 之前先把先把整个序列遍历一遍,每遍历到一个数就执行之前说过的单点修改(也就是加 $1$)即可。 ```cpp void update(int l,int r,int pos,int rt,int &p){ p=++cnt; ls(p)=ls(rt),rs(p)=rs(rt),d[p]=d[rt]+1; if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) update(l,mid,pos,ls(rt),ls(p)); else update(mid+1,r,pos,rs(rt),rs(p)); } ``` 值得一提的是,本题由于 $a_i$ 的值是 $10^9$ 级别,所以需要离散化。 ~~(总不能对着一个 `map` 建线段树吧)~~ 然后就通过了本题。 ```cpp #include<bits/stdc++.h> #define ls(p) st[p].ls #define rs(p) st[p].rs using namespace std; namespace Opshacom{ const int N=5e6+7; int n,m,a[N],tmp[N]; int len; class Chairman{ private:struct node{int ls,rs;}st[N]; public: int cnt,root[N],sum[N]; void build(int l,int r,int &p){ p=++cnt; if(l==r) return ; int mid=(l+r)>>1; build(l,mid,ls(p));build(mid+1,r,rs(p)); } void update(int l,int r,int pos,int rt,int &p){ p=++cnt; ls(p)=ls(rt),rs(p)=rs(rt),sum[p]=sum[rt]+1; if(l==r) return ; int mid=(l+r)>>1; if(pos<=mid) update(l,mid,pos,ls(rt),ls(p)); else update(mid+1,r,pos,rs(rt),rs(p)); } int query(int l,int r,int s,int t,int k){ int mid=(s+t)>>1,num=sum[ls(r)]-sum[ls(l)]; if(s==t) return s; if(k<=num) return query(ls(l),ls(r),s,mid,k); else return query(rs(l),rs(r),mid+1,t,k-num); } }tr; inline void Discretization(){ memcpy(tmp,a,sizeof(tmp)); sort(tmp+1,tmp+n+1); len=unique(tmp+1,tmp+n+1)-tmp-1; tr.build(1,len,tr.root[0]); for(int i=1;i<=n;i++) tr.update(1,len,lower_bound(tmp+1,tmp+len+1,a[i])-tmp,tr.root[i-1],tr.root[i]); } inline void work(){ cin>>n>>m; for(int i=1;i<=n;i++) cin>>a[i]; Discretization(); while(m--){ int l,r,k; cin>>l>>r>>k; cout<<tmp[tr.query(tr.root[l-1],tr.root[r],1,len,k)]<<"\n"; } } } int main(){ ios::sync_with_stdio(0); cin.tie(0), cout.tie(0); return Opshacom::work(),0; } ``` 代码其实挺短的,细节也不是很多? 其实这个还可以带修,套一个树状数组即可。因为『这是一篇(wǒ)面向(tài)普及选手(cài)的文章(le)』,所以不展开说。 ## 七、树上第 $k$ 小 例题:[Count on a Tree](https://www.luogu.com.cn/problem/P2633)。 其实就是求树上一条链上的第 $k$ 小值。 乍一看是树链剖分,其实还是主席树的板。 在序列上,我们从前往后遍历,依次单点修改,每次生成一个新的历史版本。 在树上,我们按 DFS 的顺序遍历,依次单点修改,每次生成一个新的历史版本。 在序列上,$\xi(l,r,L,R)=d_{\rho}-d_{\lambda}$, 在树上,$\xi(l,r,L,R)=d_{\rho}+d_{\lambda}-d_{v}-d_{u}$。 节点 $\rho,\lambda,v,u$ 的管辖区间都是值域 $[L,R]$,它们所处的线段树的版本分别是 $l,r,\operatorname{LCA}(l,r),fa(\operatorname{LCA}(l,r))$。 其实就是[树上差分](https://oi-wiki.org/basic/prefix-sum/#%E6%A0%91%E4%B8%8A%E5%B7%AE%E5%88%86)的原理! AC 代码:(竟然一遍过) ```cpp #include<bits/stdc++.h> using namespace std; int n,m;const int N=2e5+5; struct edge{int to,nxt;}e[N<<1]; int head[N],cntt; inline void add(int u,int v){ e[++cntt].to=v; e[cntt].nxt=head[u]; head[u]=cntt; } int cnt,sum[N<<5],ls[N<<5],rs[N<<5],a[N],tmp[N],len,root[N],fa[N],zx[N][32],dep[N]; void build(int l,int r,int &p){ p=++cnt; if(l==r) return ; int mid=(l+r)>>1; build(l,mid,ls[p]); build(mid+1,r,rs[p]); } void update(int id,int l,int r,int lst,int &p){ p=++cnt; ls[p]=ls[lst],rs[p]=rs[lst],sum[p]=sum[lst]+1; if(l==r) return ; int mid=(l+r)>>1; if(id<=mid) update(id,l,mid,ls[lst],ls[p]); else update(id,mid+1,r,rs[lst],rs[p]); } inline void Discretization(){ memcpy(tmp,a,sizeof(tmp)); sort(tmp+1,tmp+n+1); len=unique(tmp+1,tmp+n+1)-tmp-1; build(1,len,root[0]); } void dfs(int u){ update(lower_bound(tmp+1,tmp+len+1,a[u])-tmp,1,len,root[fa[u]],root[u]); for(int i=head[u];i;i=e[i].nxt){ int v=e[i].to; if(v==fa[u]) continue; fa[v]=u;dep[v]=dep[u]+1;dfs(v); } } int query(int l,int r,int lca,int flc,int s,int t,int k){ int num=sum[ls[l]]+sum[ls[r]]-sum[ls[lca]]-sum[ls[flc]]; if(s==t) return s; int mid=(s+t)>>1; if(k<=num) return query(ls[l],ls[r],ls[lca],ls[flc],s,mid,k); else return query(rs[l],rs[r],rs[lca],rs[flc],mid+1,t,k-num); } inline void init(){ for(int i=1;i<=n;i++) zx[i][0]=fa[i]; for(int j=1;j<=30;j++) for(int i=1;i<=n;i++) zx[i][j]=zx[zx[i][j-1]][j-1]; } int LCA(int u,int v){ if(u==v) return u; if(dep[u]>dep[v]) swap(u,v); for(int j=30;j>=0;j--) if(dep[u]<=dep[v]-(1<<j)) v=zx[v][j]; if(u==v) return u; for(int j=30;j>=0;j--){if(zx[v][j]!=zx[u][j]){v=zx[v][j];u=zx[u][j];}} return zx[u][0]; } int main(){ ios::sync_with_stdio(false); cin.tie(0), cout.tie(0); cin>>n>>m; for(int i=1;i<=n;i++) cin>>a[i]; Discretization(); for(int i=1;i<n;i++){ int u,v; cin>>u>>v; add(u,v); add(v,u); } dep[1]=1;dfs(1); init(); int lst=0; while(m--){ int l,r,k; cin>>l>>r>>k; l^=lst; int lca=LCA(l,r);int f=fa[lca]; lst=tmp[query(root[l],root[r],root[lca],root[f],1,len,k)]; cout<<lst<<"\n"; } return 0; } ``` 完结撒花!