浅谈线段树

Kevin_Wa

2018-11-25 09:27:52

题解

说起线段树大家都不陌生。那么不会线段数的一起进入主题吧。

现在我给你一道题。

1、一共有n个数,给你m个范围,让你求范围内的数和。(n \leq 10^6

这不是很简单,不是小学生的OI题吗?前缀和,搞定。

再加强:

2、一共有n个数,给你k个操作,让你在lr区间内加一个数,或让你求区间内的和。(n \leq 10^4

emm前缀和无力维持了。

再加强:

2、一共有n个数,给你k个操作,让你在lr区间内加一个数,或让你求区间内的和。(n \leq 10^6

这需要一种强大的数据结构线段树

下是简图(度娘上找到的)。

线段树,类似区间树,它在各个节点保存一条线段(数组中的一段子数组),主要用于高效解决连续区间的动态查询问题,

由于二叉结构的特性,它基本能保持每个操作的复杂度为O(log n)。

线段树的每个节点表示一个区间,

子节点则分别表示父节点的左右半区间,例如父亲的区间是[a,b],那么(c=(a+b)/2)左儿子的区间是[a,c],右儿子的区间是[c+1,b]。

由上图可得,

1、每个节点的左孩子区间范围为[lmid],右孩子为[mid+1,r]

2、对于结点k,左孩子结点为2k,右孩子为2k+1,这符合完全二叉树的性质。

线段树支持五种种操作:建树、单点查询、单点修改、区间查询、区间修改。

在其他延申里还支持如区间求最值。

1、建树

a、对于二分到的每一个结点,给它的左右端点确定范围。

b、如果是叶子节点,存储要维护的信息,再回到父节点时累计到父节点去。

c、合并

代码

void build(int k,int t,int w)
{ int mid;
    if (t>w) return;
    if (t==w)
      {
        tree[k].l=t;tree[k].r=w;
        tree[k].w=a[t];
        return;
      }
    mid=(t+w)/2;
    build(k*2,t,mid);
    build(k*2+1,mid+1,w);
    tree[k].l=t;tree[k].r=w;
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
}

单点查询

查询一个点的状态,设待查询点为x

a、如果当前枚举的点左右端点相等,即叶子节点,就是目标节点。

b、如果不是,所以设查询位置为x,当前结点区间范围为了l,r,中点为mid。

c、如果x<=mid,则递归它的左孩子,否则递归它的右孩子

代码

void ask(int k)
{
    if(tree[k].l==tree[k].r) //当前结点的左右端点相等,是叶子节点,是最终答案 
    {
        ans=tree[k].w;
        return ;
    }
    int m=(tree[k].l+tree[k].r)/2;
    if(x<=m) ask(k*2);//目标位置比中点靠左,就递归左孩子 
    else ask(k*2+1);//反之,递归右孩子 
}

单点修改

即更改某一个点的状态。

结合单点查询的原理,找到x的位置;根据建树状态合并的原理,修改每个结点的状态。

void add(int k)
{
    if(tree[k].l==tree[k].r)//找到目标位置 
    {
        tree[k].w+=y;
        return;
    }
    int m=(tree[k].l+tree[k].r)/2;
    if(x<=m) add(k*2);
    else add(k*2+1);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;//所有包含结点k的结点状态更新 
}

区间查询

代码

void sum(int k)
{
    if(tree[k].l>=x&&tree[k].r<=y) 
    {
        ans+=tree[k].w;
        return;
    }
    int m=(tree[k].l+tree[k].r)/2;
    if(x<=m) sum(k*2);
    if(y>m) sum(k*2+1);
}

区间修改

同理

我们不要递归到每个节点。所以要有一个新的概念:懒标记。

就像新年的时候的压岁钱,只有要用的时候才用,不要的直接给父母保管。

所以,传下来的更改值若在一个区间里,就不再下传,修改完该节点信息后,在此节的懒标记上打一个更改值。

当需要递归这个节点的子节点时,标记下传给子节点。这里不必管用哪个子节点,两个都传下去。

①当前节点的懒标记累积到子节点的懒标记中。

②修改子节点状态。在引例中,就是原状态+子节点区间点的个数父节点传下来的懒标记。

③父节点懒标记清0。这个懒标记已经传下去了,欠债还清,不用再还了。

下传代码

void pushdown(int k)
{
    tree[k*2].w+=((tree[k*2].r-tree[k*2].l+1)*tree[k].f);
    tree[k*2+1].w+=((tree[k*2+1].r-tree[k*2+1].l+1)*tree[k].f);
    tree[k*2].f+=tree[k].f;
    tree[k*2+1].f+=tree[k].f;
    tree[k].f=0;
}

区间修改代码

void add(int k,int t,int w)
{ int mid;
if (t>w) return;
    if (x<=t&&w<=y) 
      {
        tree[k].w+=((w-t+1)*z);
        tree[k].f+=z;
        return ;
      }
    mid=(t+w)/2;
    if (tree[k].f) pushdown(k); 
    if (x<=mid) add(k*2,t,mid);
    if (y>mid) add(k*2+1,mid+1,w);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
}

单点查询代码

 void ask(int k)//单点查询
{
    if(tree[k].l==tree[k].r)
    {
        ans=tree[k].w;
        return ;
    }
    if(tree[k].f) pushdown(k);//懒标记下传,唯一需要更改的地方
    int m=(tree[k].l+tree[k].r)/2;
    if(x<=m) ask(k*2);
    else ask(k*2+1);
}

区间查询代码

int ask(int k,int t,int w)
{ int mid;
if (t>w) return 0;
    if (x<=t&&w<=y)
      {
        return tree[k].w;
      }
    mid=(t+w)/2;
    if (tree[k].f) pushdown(k); 
    int sum=0;
    if (x<=mid)sum+=ask(k*2,t,mid);
    if (y>mid)sum+=ask(k*2+1,mid+1,w);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
    return sum;
}

总结

#include<cstdio>
using namespace std;
int n,p,a,b,m,x,y,ans;
struct node
{
    int l,r,w,f;
}tree[400001];
inline void build(int k,int ll,int rr)//建树 
{
    tree[k].l=ll,tree[k].r=rr;
    if(tree[k].l==tree[k].r)
    {
        scanf("%d",&tree[k].w);
        return;
    }
    int m=(ll+rr)/2;
    build(k*2,ll,m);
    build(k*2+1,m+1,rr);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
}
inline void down(int k)//标记下传 
{
    tree[k*2].f+=tree[k].f;
    tree[k*2+1].f+=tree[k].f;
    tree[k*2].w+=tree[k].f*(tree[k*2].r-tree[k*2].l+1);
    tree[k*2+1].w+=tree[k].f*(tree[k*2+1].r-tree[k*2+1].l+1);
    tree[k].f=0;
}
inline void ask_point(int k)//单点查询
{
    if(tree[k].l==tree[k].r)
    {
        ans=tree[k].w;
        return ;
    }
    if(tree[k].f) down(k);
    int m=(tree[k].l+tree[k].r)/2;
    if(x<=m) ask_point(k*2);
    else ask_point(k*2+1);
}
inline void change_point(int k)//单点修改 
{
    if(tree[k].l==tree[k].r)
    {
        tree[k].w+=y;
        return;
    }
    if(tree[k].f) down(k);
    int m=(tree[k].l+tree[k].r)/2;
    if(x<=m) change_point(k*2);
    else change_point(k*2+1);
    tree[k].w=tree[k*2].w+tree[k*2+1].w; 
}
inline void ask_interval(int k)//区间查询 
{
    if(tree[k].l>=a&&tree[k].r<=b) 
    {
        ans+=tree[k].w;
        return;
    }
    if(tree[k].f) down(k);
    int m=(tree[k].l+tree[k].r)/2;
    if(a<=m) ask_interval(k*2);
    if(b>m) ask_interval(k*2+1);
}
inline void change_interval(int k)//区间修改 
{
    if(tree[k].l>=a&&tree[k].r<=b)
    {
        tree[k].w+=(tree[k].r-tree[k].l+1)*y;
        tree[k].f+=y;
        return;
    }
    if(tree[k].f) down(k);
    int m=(tree[k].l+tree[k].r)/2;
    if(a<=m) change_interval(k*2);
    if(b>m) change_interval(k*2+1);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
}
int main()
{
    scanf("%d",&n);//n个节点 
    build(1,1,n);//建树 
    scanf("%d",&m);//m种操作 
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&p);
        ans=0;
        if(p==1)
        {
            scanf("%d",&x);
            ask_point(1);//单点查询,输出第x个数 
            printf("%d",ans);
        } 
        else if(p==2)
        {
            scanf("%d%d",&x,&y);
            change_point(1);//单点修改 
        }
        else if(p==3)
        {
            scanf("%d%d",&a,&b);//区间查询 
            ask_interval(1);
            printf("%d\n",ans);
        }
        else
        {
             scanf("%d%d%d",&a,&b,&y);//区间修改 
             change_interval(1);
        }
    }
}

例题

P3372 【模板】线段树 1

模板题

#include<bits/stdc++.h>
#define N 100010
#define int long long
using namespace std;
struct node{
    int l,r,w,f;
}tree[N*4];
int x,y,z,i,a[N],n,m;
void build(int k,int t,int w)
{ int mid;
    if (t>w) return;
    if (t==w)
      {
        tree[k].l=t;tree[k].r=w;
        tree[k].w=a[t];
        return;
      }
    mid=(t+w)/2;
    build(k*2,t,mid);
    build(k*2+1,mid+1,w);
    tree[k].l=t;tree[k].r=w;
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
}
void pushdown(int k)
{
    tree[k*2].w+=((tree[k*2].r-tree[k*2].l+1)*tree[k].f);
    tree[k*2+1].w+=((tree[k*2+1].r-tree[k*2+1].l+1)*tree[k].f);
    tree[k*2].f+=tree[k].f;
    tree[k*2+1].f+=tree[k].f;
    tree[k].f=0;
}
void add(int k,int t,int w)
{ int mid;
if (t>w) return;
    if (x<=t&&w<=y) 
      {
        tree[k].w+=((w-t+1)*z);
        tree[k].f+=z;
        return ;
      }
    mid=(t+w)/2;
    if (tree[k].f) pushdown(k); 
    if (x<=mid) add(k*2,t,mid);
    if (y>mid) add(k*2+1,mid+1,w);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
}
int ask(int k,int t,int w)
{ int mid;
if (t>w) return 0;
    if (x<=t&&w<=y)
      {
        return tree[k].w;
      }
    mid=(t+w)/2;
    if (tree[k].f) pushdown(k); 
    int sum=0;
    if (x<=mid)sum+=ask(k*2,t,mid);
    if (y>mid)sum+=ask(k*2+1,mid+1,w);
    tree[k].w=tree[k*2].w+tree[k*2+1].w;
    return sum;
}
int read(int &x)
{
    char c=getchar();int f=1;
    x=0;
    while (c<'0'||c>'9')
      {
      if (c=='-') f=-1;
      c=getchar();
      }
    while (c>='0'&&c<='9')
      {
        x=x*10+(int)c-48;
        c=getchar();
      }
    return x*f;
}
signed main()
{
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
    n=read(n);m=read(m);
    memset(a,0,sizeof(a));
    for (int i=1;i<=n;i++) a[i]=read(a[i]);
    build(1,1,n);
    for (int j=1;j<=m;j++)
      { int c;
        c=read(c);x=read(x);y=read(y);
        if (c==1) 
          {
            z=read(z);
            add(1,1,n);
            }
        else printf("%lld\n",ask(1,1,n));
      }

}

P3373 【模板】线段树 2

套了乘法,注意清1。

#include <bits/stdc++.h>//区间
#define N 400001
#define ll long long
using namespace std;

struct node{ll mul,add,val;} st[N];
ll n,m,k,c,x,y,z;
ll a[N];

void build(ll x,ll l,ll r)
{
    st[x].mul=1;st[x].add=0;
    if (l==r) {st[x].val=a[l];return;}
    ll mid=l+r>>1;
    build(x<<1,l,mid);
    build(x<<1|1,mid+1,r);
    st[x].val=(st[x<<1].val+st[x<<1|1].val)%k;
}

void pushdown(ll x,ll l,ll r)
{
    ll mid=l+r>>1;
    st[x<<1].val=(st[x<<1].val*st[x].mul+st[x].add*(mid-l+1))%k;
    st[x<<1].mul=(st[x<<1].mul*st[x].mul)%k;
    st[x<<1].add=(st[x<<1].add*st[x].mul+st[x].add)%k;
    st[x<<1|1].val=(st[x<<1|1].val*st[x].mul+st[x].add*(r-mid))%k;
    st[x<<1|1].mul=(st[x<<1|1].mul*st[x].mul)%k;
    st[x<<1|1].add=(st[x<<1|1].add*st[x].mul+st[x].add)%k;
    st[x].mul=1;st[x].add=0;
}

void update_mul(ll x,ll l,ll r,ll ql,ll qr,ll mul)
{
    if (l>qr||ql>r) return;
    if (l>=ql&&qr>=r)
        {
            st[x].val=(st[x].val*mul)%k;
            st[x].mul=(st[x].mul*mul)%k;
            st[x].add=(st[x].add*mul)%k;
            return;
        }
    ll mid=l+r>>1;
    pushdown(x,l,r);
    update_mul(x<<1,l,mid,ql,qr,mul);
    update_mul(x<<1|1,mid+1,r,ql,qr,mul);
    st[x].val=(st[x<<1].val+st[x<<1|1].val)%k;
}

void update_add(ll x,ll l,ll r,ll ql,ll qr,ll add)
{
    if (l>qr||ql>r) return;
    if (l>=ql&&qr>=r)
        {
            st[x].val=(st[x].val+add*(r-l+1))%k;
            st[x].add=(st[x].add+add)%k;
            return;
        }
    ll mid=l+r>>1;
    pushdown(x,l,r);
    update_add(x<<1,l,mid,ql,qr,add);
    update_add(x<<1|1,mid+1,r,ql,qr,add);
    st[x].val=(st[x<<1].val+st[x<<1|1].val)%k;
}

ll query(ll x,ll l,ll r,ll ql,ll qr)
{
    if (l>qr||ql>r) return 0;
    if (l>=ql&&qr>=r) return st[x].val;
    ll mid=l+r>>1;
    pushdown(x,l,r);
    return (query(x<<1,l,mid,ql,qr)+query(x<<1|1,mid+1,r,ql,qr))%k;
}

int main()
{
    scanf("%lld%lld%lld",&n,&m,&k);
    for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
    build(1,1,n);
    for (int i=1;i<=m;i++)
        {
            scanf("%lld%lld%lld",&c,&x,&y);
            if (c==1) scanf("%lld",&z),update_mul(1,1,n,x,y,z);
            if (c==2) scanf("%lld",&z),update_add(1,1,n,x,y,z);
            if (c==3) printf("%lld\n",query(1,1,n,x,y));
        }
    return 0;
}

P1440 求m区间内的最小值

可能线段树不是最优,但要知道线段也支持这个操作

#include<bits/stdc++.h>
#define N 2000010
using namespace std;
struct node{
    int l,r,w,f;
}tree[N*2+1];
int x,y,z,i,a[N],n,m;
void build(int k,int t,int w)
{ int mid;
    if (t>w) return;
    if (t==w)
      {
        tree[k].l=t;tree[k].r=w;
        tree[k].w=a[t];
        return;
      }
    mid=(t+w)/2;
    build(k*2,t,mid);
    build(k*2+1,mid+1,w);
    tree[k].l=t;tree[k].r=w;
    tree[k].w=min(tree[k*2].w,tree[k*2+1].w);
}
int ask(int k,int t,int w)
{ int mid;
if (t>w) return 0;
    if (x<=t&&w<=y)
      {
        return tree[k].w;
      }
    mid=(t+w)/2;
    int sum=INT_MAX;
    if (x<=mid)sum=min(sum,ask(k*2,t,mid));
    if (y>mid)sum=min(sum,ask(k*2+1,mid+1,w));
    return sum;
}
int read(int &x)
{
    char c=getchar();int f=1;
    x=0;
    while (c<'0'||c>'9')
      {
      if (c=='-') f=-1;
      c=getchar();
      }
    while (c>='0'&&c<='9')
      {
        x=x*10+(int)c-48;
        c=getchar();
      }
    return x*f;
}
signed main()
{
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
    n=read(n);m=read(m);
    memset(a,INT_MAX,sizeof(a));
    for (int i=1;i<=n;i++) a[i]=read(a[i]);
    build(1,1,n);
    for (int i=1;i<=n;i++)
      { int c;
        x=i-m;y=i-1;
        if (x<=0) x=1;
        if (x>y) 
          {
            printf("0\n");
            continue;
            }
        printf("%d\n",ask(1,1,n));
      }

}