题解 P3369 【【模板】普通平衡树(Treap/SBT)】

远航之曲

2017-04-24 07:37:30

题解

这里给一种非旋转Treap的详细教程

建议有了splay的基础再看这个教程

反对楼下说非旋转Treap慢,实地测试比splay快很多

想学更多Treap的姿势欢迎来博客

首先Orz fhq %%%%%

何为Treap

Treap是一种二叉平衡树。Treap=BST+Heap

Treap比较好玩的一点是,它整体是拥有二叉搜索树的性质,但是它的每一个节点都会有一个附加权值,它的附加权值是符合堆的性质。

这样我们可以明显看出,这棵树的形态(平衡与否)是由这个附加权值决定的。

那我们该如何取这个权值呢?随机!!

是的,我们每次都随机一个权值作为它的附加权值,那么它就*大概*是平衡的。

但是一般的平衡树都会有很恼人的旋转,又不好写,又不好调。

这里就郑重推出:fhq treap!

只需要两个基础操作,就可以达到splay的所有功能

1: ¥split¥

它的主要思想就是把一个Treap分成两个。

split操作有两种类型,一种是按照权值分配,一种是按前k个分配。

第一种就是把所有小于k的权值的节点分到一棵树中,第二种是把前k个分到一个树里。

权值版:

void split(int now,int k,int &x,int &y)
{
    if (!now) x=y=0;
    else
    {
        if (val[now]<=k)
            x=now,split(ch[now][1],k,ch[now][1],y);
        else
            y=now,split(ch[now][0],k,x,ch[now][0]);
        update(now);
    }
}

对于我们遍历到每一个点,假如它的权值小于k,那么它的所有左子树,都要分到左边的树里,然后遍历它的右儿子。假如大于k,把它的所有右子树分到右边的树里,遍历左儿子。

因为它的最多操作次数就是一直分到底,效率就是O(\log n)

对于前k个版的,就是像找第k大的感觉。每次减掉siz

void split(int now,int k,int &x,int &y)
{
    if (!now) x=y=0;
    else
    {
        if (k<=siz[ch[now][0]])
            y=now,split(ch[now][0],k,x,ch[now][0]);
        else
            x=now,split(ch[now][1],k-siz[ch[now][0]]-1,ch[now][1],y);
        update(now);
    }
}

2: merge

这个就是把两个Treap合成一个,保证第一个的权值小于第二个。

因为第一个Treap的权值都比较小,我们比较一下它的tar(附加权值),假如第一个的tar小,我们就可以直接保留它的所有左子树,接着把第一个Treap变成它的右儿子。反之,我们可以保留第二棵的所有右子树,指针指向左儿子。

你可以把这个过程形象的理解为在第一个Treap的左子树上插入第二个树,也可以理解为在第二个树的左子树上插入第一棵树。因为第一棵树都满足小于第二个树,所以就变成了比较tar来确定树的形态。

也就是说,我们其实是遍历了第一个trep的根->最大节点,第二个Treap的根->最小节点,也就是O(\log n)

代码:

int merge(int x,int y)
{
    if (!x || !y) return x+y;
    if (pri[x]<pri[y])
    {
        ch[x][1]=merge(ch[x][1],y);
        update(x);
        return x;
    }
    else
    {
        ch[y][0]=merge(x,ch[y][0]);
        update(y);
        return y;
    }
}

下面我们就可以通过这两个基本的东西实现各种各样的操作了。

3:insert

插入一个权值为v的点,把树按照v的权值split成两个,再按照顺序merge回去。

split(root,v,x,y);
root=merge(merge(x,new_node(v)),y);

4: del

删除权值为v的点,把树按照v分成两个a,b,再把a按照v-1分成c,d。把c的两个子儿子merge起来,再merge(merge(c,d),b)

split(root,v,x,z);
split(x,v-1,x,y);
y=merge(ch[y][0],ch[y][1]);
root=merge(merge(x,y),z);

5: precursor

找前驱的话把root按v-1 split成x,y,在x里面找最大值

6: successor

找后继的话把root按v split成x,y,在y里找最小值

7: rank

把root按v-1 split成x,y,排名是x的siz

代码

#include <cstdio>
#include <cstdlib>
#include <ctime>
#define N 500005
using namespace std;
int inline read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int ch[N][2],val[N],pri[N],siz[N],sz;
void update(int x){siz[x]=1+siz[ch[x][0]]+siz[ch[x][1]];}
int new_node(int v)
{
    siz[++sz]=1;
    val[sz]=v;
    pri[sz]=rand();
    return sz;
}
int merge(int x,int y)
{
    if (!x || !y) return x+y;
    if (pri[x]<pri[y])
    {
        ch[x][1]=merge(ch[x][1],y);
        update(x);
        return x;
    }
    else
    {
        ch[y][0]=merge(x,ch[y][0]);
        update(y);
        return y;
    }
}
void split(int now,int k,int &x,int &y)
{
    if (!now) x=y=0;
    else
    {
        if (val[now]<=k)
            x=now,split(ch[now][1],k,ch[now][1],y);
        else
            y=now,split(ch[now][0],k,x,ch[now][0]);
        update(now);
    }
}
int kth(int now,int k)
{
    while(1)
    {
        if (k<=siz[ch[now][0]])
            now=ch[now][0];
        else
        if (k==siz[ch[now][0]]+1)
            return now;
        else
            k-=siz[ch[now][0]]+1,now=ch[now][1];
    }
}
main()
{
    srand((unsigned)time(NULL));
    int T,com,x,y,z,a,b,root=0;
    scanf("%d",&T);
    while(T--)
    {
        com=read(),a=read();
        if (com==1)
        {
            split(root,a,x,y);
            root=merge(merge(x,new_node(a)),y);
        }
        else
        if (com==2)
        {
            split(root,a,x,z);
            split(x,a-1,x,y);
            y=merge(ch[y][0],ch[y][1]);
            root=merge(merge(x,y),z);
        }
        else
        if (com==3)
        {
            split(root,a-1,x,y);
            printf("%d\n",siz[x]+1);
            root=merge(x,y);
        }
        else
        if (com==4)
            printf("%d\n",val[kth(root,a)]);
        else
        if (com==5)
        {
            split(root,a-1,x,y);
            printf("%d\n",val[kth(x,siz[x])]);
            root=merge(x,y);
        }
        else
        {
            split(root,a,x,y);
            printf("%d\n",val[kth(y,1)]);
            root=merge(x,y);
        }
    }
}