『进阶』浅谈线段树

· · 算法·理论

一般线段树

初识线段树

一、问题产生

我们以本题为例,P3372 【模板】线段树 1。

二、问题分析

概括来说,本题题有两种操作:

本题的操作 1 为区间修改,操作 2 为区间求和。

三、寻找算法

1. 暴力(万物皆可暴力

算法说明:

使用数组直接存储。操作 1 使用循环一次给每一个数加上 x;操作 2 也使用循环依次求和。

好处:
坏处:

再看数据范围:1 \le n,m \le 10^5,TLE 无疑了。

2. 前缀和

算法说明:

使用前缀和数组存储。操作 1 需要将前缀和数组重新作处理;操作 2 可以直接算出。

好处:
坏处:

所以只能处理静态数组的区间问题的算法都无法通过本题(如 ST 表)。

3. 差分

使用差分数组存储。操作 1 可以直接修改左右端点;操作 2 需要重新还原

好处:
坏处:

依然无法通过本题。

4. 树状数组

本算法也可过,因主讲线段树,现暂时省略。

5. 线段树(本场主角)

既然本题是动态区间查询问题,我们要让修改和查询时间复杂度都要降低。很多人都想到了树形结构分治,我们将它们融合在一起,就成了线段树。

四、算法讲解

(1) 创建线段树

线段树既然是二叉树,就是二叉树的建树方法。

唯一不同的一点,就是这个结点的数据来源于其孩子节点的数据,有多添加一行计算代码。(可用函数封装)

代码如下:

struct asd{
    long long sum,tag;//sum:线段树数据(此为和);tag:懒标记(见后文区间修改)
};
vector<asd>t;//创建线段树数组
int n,m;
asd merge(asd a,asd b){//封装结点计算函数
    asd ret{a.sum+b.sum,0};
    return ret;
}
void build(const vector<int> &a,int k,int cl,int cr){//创建线段树。k:结点编号;cl:指定查找区间左端点;cr:指定查找区间右端点
    if(cl==cr){//如果此结点为叶结点
        t[k].sum=a[cl];//赋值
        return;
    }
    int lc=k<<1,rc=lc+1;//计算左孩子和右孩子的编号
    int mid=(cl+cr)>>1;//计算中点
    build(a,lc,cl,mid);//向左孩子递归
    build(a,rc,mid+1,cr);//向右孩子递归
    t[k]=merge(t[lc],t[rc]);//计算当前结点
}
void init(const vector<int> &a){//初始化
    n=a.size();
    t.resize(n<<2);//记得开四倍空间
    build(a,1,1,n);//创建线段树
}

(2) 区间查询

我们画个图来理解:

本图为 n=5 的线段树,我们先进行区间查询:

我们要求出 [1,4] 内所有元素的和。

步骤如下:

第一步:[1,4] 在结点 [1,5] 左孩子和右孩子均有跨度,所以继续递归进入两个子节点。

第二步:[1,4] 包含结点 [1,3],进行回溯;区间 [4,4] 在结点 [4,5] 的左孩子中,向左孩子进行递归。

第三步:[1,4] 结点包含 [4,4],进行回溯。

从中,我们知道要查找指定区间的和或最小值等数据,需要递归查找区间所包含且线段树拥有的子区间。(如图中的 [1,3][4,4]

代码如下:

long long qry(int k,int l,int r,int cl,int cr){//k:结点编号;l:此结点区间左端点;r:此结点区间右端点;cl:指定查找区间左端点;cr:指定查找区间右端点
    if(cl>r||cr<l){//如果此结点区间在指定查找区间之外
        return 0;//退出
    }
    if(cl>=l&&cr<=r){//如果指定查找区间完全包含此结点区间
        return t[k].sum;//返回子区间的数据
    }
    int lc=k<<1,rc=lc+1;//计算左孩子和右孩子的编号
    int mid=(cl+cr)>>1;//计算中点
    push_down(k,cl,cr);//懒标记向下传递(见后文区间修改)
    return qry(lc,l,r,cl,mid)+qry(rc,l,r,mid+1,cr);//回溯
}

(3) 区间修改

区间修改是最难想的一点,因为每次修改都要把区间的子区间都全部修改,时间复杂度仍然爆炸巨大,那怎么解决这个问题?

既然每次修改都要把区间的子区间都全部修改,那不如留个标记,等到要使用时再传到子区间并修改数据。

代码如下:

void push_down(int k,int cl,int cr){//下传懒标记
    int lc=k<<1,rc=lc+1;//计算左孩子和右孩子的编号
    int mid=(cl+cr)>>1;//计算中点
    t[lc].tag+=t[k].tag;//下传懒标记
    t[rc].tag+=t[k].tag;//下传懒标记
    t[lc].sum+=(mid-cl+1)*t[k].tag;//更新数据
    t[rc].sum+=(cr-mid)*t[k].tag;//更新数据
    t[k].tag=0;//删除原来的标记
}
void modify(int k,int l,int r,int cl,int cr,long long x){//区间修改
    if(cl>r||cr<l){//如果此结点区间在指定查找区间之外
        return;//退出
    }
    if(l<=cl&&r>=cr){//如果指定查找区间完全包含此结点区间
        t[k].tag+=x;//添加懒标记
        t[k].sum+=x*(cr-cl+1);//更新数据
        return;//退出
    }
    int lc=k<<1,rc=lc+1;//计算左孩子和右孩子的编号
    int mid=(cl+cr)>>1;//计算中点
    push_down(k,cl,cr);//下传懒标记
    modify(lc,l,r,cl,mid,x);//向左区间进行递归
    modify(rc,l,r,mid+1,cr,x);//向右区间进行递归
    t[k]=merge(t[lc],t[rc]);//计算更新后结点数据
}

完整代码如下:

# include<bits/stdc++.h>
using namespace std;
struct asd{
    long long sum,tag;
};
vector<asd>t;
int n,m;
asd merge(asd a,asd b){
    asd ret{a.sum+b.sum,0};
    return ret;
}
void push_down(int k,int cl,int cr){
    int lc=k<<1,rc=lc+1;
    int mid=(cl+cr)>>1;
    t[lc].tag+=t[k].tag;
    t[rc].tag+=t[k].tag;
    t[lc].sum+=(mid-cl+1)*t[k].tag;
    t[rc].sum+=(cr-mid)*t[k].tag;
    t[k].tag=0;
}
void build(const vector<int> &a,int k,int cl,int cr){
    if(cl==cr){
        t[k].sum=a[cl];
        return;
    }
    int lc=k<<1,rc=lc+1;
    int mid=(cl+cr)>>1;
    build(a,lc,cl,mid);
    build(a,rc,mid+1,cr);
    t[k]=merge(t[lc],t[rc]);
}
void init(const vector<int> &a){
    n=a.size();
    t.resize(n<<2);
    build(a,1,1,n);
}
long long qry(int k,int l,int r,int cl,int cr){
    if(cl>r||cr<l){
        return 0;
    }
    if(cl>=l&&cr<=r){
        return t[k].sum;
    }
    int lc=k<<1,rc=lc+1;
    int mid=(cl+cr)>>1;
    push_down(k,cl,cr);
    return qry(lc,l,r,cl,mid)+qry(rc,l,r,mid+1,cr);
}
void modify(int k,int l,int r,int cl,int cr,long long x){
    if(cl>r||cr<l){
        return;
    }
    if(l<=cl&&r>=cr){
        t[k].tag+=x;
        t[k].sum+=x*(cr-cl+1);
        return;
    }
    int lc=k<<1,rc=lc+1;
    int mid=(cl+cr)>>1;
    push_down(k,cl,cr);
    modify(lc,l,r,cl,mid,x);
    modify(rc,l,r,mid+1,cr,x);
    t[k]=merge(t[lc],t[rc]);
}

五、算法分析

好处:

坏处

六、算法答疑

1. 为什么要开 4n 的大小?

#### 2. 线段树具体主要需要改那些地方才能改变其作用? 具体有 merge 函数: ```cpp asd merge(asd a,asd b){ asd ret{a.sum+b.sum,0};//<- return ret; } ``` push\_down 函数: ```cpp void push_down(int k,int cl,int cr){ int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; t[lc].tag+=t[k].tag;//<- t[rc].tag+=t[k].tag;//<- t[lc].sum+=(mid-cl+1)*t[k].tag;//<- t[rc].sum+=(cr-mid)*t[k].tag;//<- t[k].tag=0; } ``` 添加懒标记部分: ``` void modify(int k,int l,int r,int cl,int cr,long long x){ if(cl>r||cr<l){ return; } if(l<=cl&&r>=cr){ t[k].tag+=x;//<- t[k].sum+=x*(cr-cl+1);//<- return; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); modify(lc,l,r,cl,mid,x); modify(rc,l,r,mid+1,cr,x); t[k]=merge(t[lc],t[rc]); } ``` qry 函数的返回值和计算部分: ```cpp long long qry(int k,int l,int r,int cl,int cr){ if(cl>r||cr<l){ return 0;//<- } if(cl>=l&&cr<=r){ return t[k].sum; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); return qry(lc,l,r,cl,mid)+qry(rc,l,r,mid+1,cr);//<- } ``` 剩下部分**几乎**无改动。 #### 3. 线段树适用于那些题? 线段树适用于动态数组的单点 or 区间修改、单点 or 区间查询。 ### 七、推荐题目: - #### [P3372 【模板】线段树 1](https://www.luogu.com.cn/problem/P3372) - #### [P1531 I Hate It](https://www.luogu.com.cn/problem/P1531) - #### [P2068 统计和](https://www.luogu.com.cn/problem/P2068) ## 一般线段树扩展知识 ### 一、问题产生 我们以本题为例,[P1253 扶苏的问题](https://www.luogu.com.cn/problem/P1253)。 ### 二、问题分析 概括来说,本题题有三种操作: - `1 l r x`:将每个 $a_i$ 修改为 $x$,且 $l \le i \le r$。 - `2 l r x`:将每个 $a_i$ 增加 $x$,且 $l \le i \le r$。 - `3 l r`:求出 $[a_l,a_r]$ 之间的最大值。 本题的操作 1 和操作 2 为区间修改,操作 3 为区间最大值。 ### 三、算法讲解 本题与上题一样,但是有两个区间修改的操作: - 操作 1 为区间推平。 - 操作 2 为区间增加。 懒标记是一个变量,不能同时执行两个操作,那就直接多加一个懒标记维护区间修改。 那区间推平怎么写呢? 我们来看一下。 ![](https://cdn.luogu.com.cn/upload/image_hosting/akbtsgjw.png) 推平是将元素赋值,可只用懒标记标记推平的数值。 注意事项见代码。 代码如下: ```cpp asd merge(asd a,asd b){ asd ret{max(a.maxx,b.maxx),0x3f3f3f3f};//设置tag为一个最大值标记,防止推平函数将元素不小心推成0 return ret; } void push_down(int k,int cl,int cr){ int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; if(t[k].tag==0x3f3f3f3f){//判断标记 return; } t[lc].tag=t[k].tag; t[rc].tag=t[k].tag; t[lc].maxx=t[k].tag;//推平后所有元素为同一数值,所以最大值为推平数值 t[rc].maxx=t[k].tag; t[k].tag=0x3f3f3f3f;//标记 } void init(const vector<int> &a){ n=a.size(); t.resize(n<<2); for(int i=0;i<t.size();i++){ t[i].tag=-0x3f3f3f3f;//初始化标记 t[i].maxx=-0x3f3f3f3f;//设置最小值 } build(a,1,0,n-1); } long long qry(int k,int l,int r,int cl,int cr){ if(cl>r||cr<l){ return -0x3f3f3f3f;//返回最小值以便取最大值 } if(cl>=l&&cr<=r){ return t[k].maxx; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); return max(qry(lc,l,r,cl,mid),qry(rc,l,r,mid+1,cr));//返回最大值 } void modify(int k,int l,int r,int cl,int cr,int x){ if(cl>r||cr<l){ return; } if(l<=cl&&r>=cr){ t[k].tag=x; t[k].maxx=x;//推平后所有元素为同一数值,所以最大值为推平数值,同理 return; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); modify(lc,l,r,cl,mid,x); modify(rc,l,r,mid+1,cr,x); t[k]=merge(t[lc],t[rc]); } ``` 推平清楚了,就可以做本题了。 本题要注意两个懒标记如何兼容。 代码如下: ```cpp #include<bits/stdc++.h> using namespace std; struct asd{ long long maxx,tag,tag2; }; vector<asd>t; int n,m; asd merge(asd a,asd b){ asd ret{max(a.maxx,b.maxx),0,-0x3f3f3f3f3f3f3f}; return ret; } void push_down(int k,int cl,int cr){ int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; if(t[k].tag2!=-0x3f3f3f3f3f3f3f){//如果此区间被推平了 t[lc].tag2=t[k].tag2;//直接推平 t[rc].tag2=t[k].tag2; t[lc].maxx=t[k].tag2; t[rc].maxx=t[k].tag2; }else{//如果没有被推平 if(t[lc].tag2!=-0x3f3f3f3f3f3f3f){//如果子区间被推平了 t[lc].tag2+=t[k].tag;//将加和了加到推平中 }else{ t[lc].tag+=t[k].tag;//否则直接加 } if(t[rc].tag2!=-0x3f3f3f3f3f3f3f){//同理 t[rc].tag2+=t[k].tag; }else{ t[rc].tag+=t[k].tag; } t[lc].maxx+=t[k].tag; t[rc].maxx+=t[k].tag; } t[k].tag=0; t[k].tag2=-0x3f3f3f3f3f3f3f; } void build(const vector<int> &a,int k,int cl,int cr){ if(cl==cr){ t[k].maxx=a[cl]; return; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; build(a,lc,cl,mid); build(a,rc,mid+1,cr); t[k]=merge(t[lc],t[rc]); } void init(const vector<int> &a){ n=a.size(); t.resize(n<<2); for(int i=0;i<t.size();i++){ t[i].tag2=-0x3f3f3f3f3f3f3f; } build(a,1,1,n); } long long qry(int k,int l,int r,int cl,int cr){ if(cl>r||cr<l){ return -1e15; } if(cl>=l&&cr<=r){ return t[k].maxx; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); return max(qry(lc,l,r,cl,mid),qry(rc,l,r,mid+1,cr)); } void modify(int k,int l,int r,int cl,int cr,long long x){ if(cl>r||cr<l){ return; } if(l<=cl&&r>=cr){ t[k].tag=0;//推平要把加和覆盖 t[k].tag2=x; t[k].maxx=x; return; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); modify(lc,l,r,cl,mid,x); modify(rc,l,r,mid+1,cr,x); t[k]=merge(t[lc],t[rc]); } void modify2(int k,int l,int r,int cl,int cr,long long x){ if(cl>r||cr<l){ return; } if(l<=cl&&r>=cr){ if(t[k].tag2!=-0x3f3f3f3f3f3f3f){//如果此区间有推平标记 t[k].tag2+=x;//直接累加 t[k].maxx+=x;//最大值也累加 }else{ t[k].tag+=x;//累加到加和标记上 t[k].maxx+=x; } return; } int lc=k<<1,rc=lc+1; int mid=(cl+cr)>>1; push_down(k,cl,cr); modify2(lc,l,r,cl,mid,x); modify2(rc,l,r,mid+1,cr,x); t[k]=merge(t[lc],t[rc]); } ``` ## 四、推荐题目 - ### [P1253 扶苏的问题](https://www.luogu.com.cn/problem/P1253) - ### [P3373 【模板】线段树 2](https://www.luogu.com.cn/problem/P3373) ## 总结 ‌线段树‌是一种用于处理区间查询和更新操作的数据结构,特别适用于处理大量数据的区间查询问题。线段树通过将区间递归分解为更小的区间,并存储每个区间的信息,从而高效地解决区间查询和更新问题。 线段树是一种二叉搜索树,每个节点代表一个区间,存储该区间的信息。 线段树要注意懒标记的标记和传播、查找区间的递归判断。因为码量巨大,需要牢记,防止细节错误,修改很浪费时间。(我有过) 我认为能用树状数组用树状数组,线段树太容易出错了,而且空间占用很大。 # 动态开点线段树 ### 一、提出普通线段树的不足 前面讲到普通线段树需要开 4 倍的数组。但空间复杂度实在太大了,$[0,10^9]$ 肯定空间会超出。 ### 二、改进算法 为了节省空间,我们可以不用直接将树建好,而是在最初只建立一个根结点代表整个区间。只有我们需要访问某个没有定义的子区间时,才建立代表这个区间的结点。这样我们不再使用 2p 和 2p+1 代表 p 结点的儿子,而是用 $\text{l}$ 和 $\text{r}$ 记录儿子的编号。 ### 三、代码演绎 #### 1. 区间 pushdown 动态开点线段树的 `push_down` 要在一般线段树的 `push_down` 中加上开点操作。 ```cpp void push_down(int u,int len){//u:当前区间编号;len:区间长度 if(!lc[u]){//如果左孩子未被创建 lc[u]=++cnt;//给编号为u的结点左孩子编号 t[lc[u]].len=len-len/2;//计算区间长度 } if(!rc[u]){//同理 rc[u]=++cnt; t[rc[u]].len=len/2; } /* 懒标记操作(见前文) */ } ``` #### 2. 区间查询 区间查询依旧不变,只需要改一改形参,不用额外添加开点,因为结点要么已经创建,要么在 `push_down` 函数中创建。 ```cpp long long qry(int u,int l,int r,int L,int R){//u:当前区间编号;l,r:当前区间左右端点;L,R:要查询的区间的左右端点 if(L<=l&&r<=R){ return tr[u].v; } if(r<L||l>R){ return 0; } int mid=(l+r)>>1; pushdown(u,r-l+1); return qry(lc[u],l,mid,L,R)+qry(rc[u],mid+1,r,L,R); } ``` #### 3. 除了形参几乎无改动。 ```cpp void modify(int u,int l,int r,int L,int R,int d){//u:当前区间编号;l,r:当前区间左右端点;L,R:要修改的区间;d:修改(不一定是推平)的数值 if(L<=l&&r<=R){ t[u].tag+=d; t[u].sum+=(r-l+1LL)*d; } if(r<L||l>R){ return; } pushdown(u,r-l+1); int mid=l+r>>1; modify(lc[u],l,mid,L,R,d); modify(rc[u],mid+1,r,L,R,d); t[k]=merge(t[lc],t[rc]); } ``` 致谢 [acwing](https://www.acwing.com/blog/content/31417/) 的思路和部分代码(因为本人还没学,现在上面看了思路) # 篇尾 感谢您的观看,能不能留下你的点赞和关注?