Splay(伸展树)是一种维护二叉搜索树的数据结构,可以用它干一些很神奇的东西,这篇文章先来介绍它的基本操作。
首先定义几个变量:
- fa[x] 表示 x 的父节点
- ch[x][y] 表示 x 的儿子节点,y=0 表示左儿子,y=1 表示右儿子
- cnt[x] 表示 x 这个数出现了几次
- val[x] 表示 x 节点的权值是多少
- size[x] 表示以 x 为根的树节点个数(树的大小)
- tot_size 表示树的总大小
- root 表示当前根节点是哪个
下面介绍操作:
clear(x)
把 x 节点上的所有信息清空
void clear(int x) {
fa[x]=ch[x][0]=ch[x][1]=cnt[x]=size[x]=val[x]=0;
}
get(x)
判断 x 节点为它父节点的左儿子还是右儿子(左0右1)
int get(int x) {
return ch[fa[x]][1]==x;
}
update(x)
维护以 x 为根的树的大小
在下面的操作的时候如果会update很多点,一定要从下往上维护。
void update(int x) {
if(x) {
size[x]=cnt[x];
if(ch[x][0]) size[x]+=size[ch[x][0]];
if(ch[x][1]) size[x]+=size[ch[x][1]];
}
}
rotate(x)
Splay中最最最重要的一个环节。
把 x 节点旋转到 x 的父节点的位置。
可是这是二叉树呀,这样操作不就乱了吗?
所以我们要维护某些节点之间的父子关系。
首先我们要明确在这次操作中会涉及到的节点:
1、x,你就是转它肯定会涉及它呀
2、fa[x],你要把 x 转到那里肯定也会涉及到它
3、fa[fa[x]],把 x 转到了 fa[x] 时,fa[fa[x]] 的儿子就不是 fa[x] 了,会变成 x
好了,rotate操作就会涉及到这 3 个节点,每个节点改变它的父亲和儿子,就会有六条语句,其中如果 fa[x] 已经是根了,那么就不用改变 fa[fa[x]] 的儿子了。
最关键的问题来了:父子关系怎么分配呢?
因为我们想把 x 到 fa[x] 的位置,那么它们的的父子关系必然会互换。唯一要确定的就是左右儿子的问题。
假设 x 是 fa[x] 的左儿子,那么 fa[x] 的原本的左儿子 x 将会变成 x 的右儿子(这里为什么是右儿子,因为这样才会保持二叉搜索树的性质)。反之亦然,所以我们用 which 来记录 x 与 fa[x] 的关系,最后维护一下旋转后的树的大小(因为 fa[x] 已经是 x 的儿子了,所以先update(fa[x])),代码如下:
void rotate(int x) {
int pa=fa[x],papa=fa[pa],which=get(x);
ch[pa][which]=ch[x][!which];fa[ch[x][!which]]=pa;
ch[x][!which]=pa;fa[pa]=x;
fa[x]=papa;if(papa) ch[papa][ch[papa][1]==pa]=x;
update(pa);update(x);
}
splay(x)
这个函数是通过不断的rotate把 x 转到根的位置。
注意三点一线的时候是先转fa[x]再转x
void splay(int x) {
for(int f;f=fa[x];rotate(x)) {
if(fa[f]) rotate(get(f)==get(x)?f:x);
}
root=x;
}
insert(x)
插入一个数x。
三种情况:
1、空树。直接改改信息return就好了。
2、x重复。cnt[x]++,维护一下return。
3、找到了最底下。新开节点维护一下return。
这个具体看代码吧,应该很好理解。
下面两种情况不要忘记splay一下。
void insert(int v) {
if(root==0) {
tot_size++;
ch[tot_size][0]=ch[tot_size][1]=fa[tot_size]=0;
val[tot_size]=v;
cnt[tot_size]=size[tot_size]=1;
root=tot_size;
return;
}
int f=0,now=root;
while(true) {
if(val[now]==v) {
cnt[now]++;
update(now);
update(f);
splay(now);
return;
}
f=now;
now=ch[now][val[now]<v];
if(now==0) {
tot_size++;
ch[tot_size][0]=ch[tot_size][1]=0;
fa[tot_size]=f;
val[tot_size]=v;
cnt[tot_size]=1,size[tot_size]=1;
ch[f][val[f]<v]=tot_size;
update(f);
splay(tot_size);
return;
}
}
}
find(x)
查找x这个数的排名
就按照二叉搜索树的性质往下查找,注意我们在往左子树找的时候是不用累加结果的,因为最左边的就是第一个,在往右边找的时候再加上左子树的大小,找到的时候别忘了把 x splay到根方便以后的操作。
int find(int x) {
int res=0,now=root;
while(true) {
if(x<val[now]) {
now=ch[now][0];
}
else {
res+=size[ch[now][0]];
if(x==val[now]) {
splay(now);
return res+1;
}
res+=cnt[now];
now=ch[now][1];
}
}
}
findx(x)
查找排名为x的树的节点
和find类似无非就是多判断一下子树的大小看看能否继续查找,temp表示的是已经搜了多少个节点。
int findx(int p) {
int now=root;
while(true) {
if(ch[now][0] && p<=size[ch[now][0]]) {
now=ch[now][0];
}
else {
int temp=size[ch[now][0]]+cnt[now];
if(p<=temp) return val[now];
p-=temp;
now=ch[now][1];
}
}
}
pre() 和 next()
查找根节点的前驱和后继节点
如果要查找x的前驱或后继的话,就先insert(x),把它转到根,再del(x),删除。
这个操作很简单,根节点的前驱就是根节点左子树中最靠右的那个,后继就是右子树中最靠左的那个,想一想,为什么?
int pre() {
int now=ch[root][0];
while(ch[now][1]) now=ch[now][1];
return now;
}
int next() {
int now=ch[root][1];
while(ch[now][0]) now=ch[now][0];
return now;
}
del(x)
删除大小为x的节点
首先我们随便find一下,目的是让x转到根节点,现在root就是x。
然后就会出现下面几种情况:
1、x有重复。那么直接cnt[root]--,return就好了。
2、root没有儿子了,即树上只有x一个节点。那么直接删除根节点,return。
3、root只有左儿子或只有右儿子。那就把它的这个儿子变成父亲,然后删除父亲,return。
4、root有两个儿子。那么为了满足二叉搜索树的性质,我们把根的前驱变成新的根,再把原来根的右子树接到新根的右儿子上,最后删除原来的根,维护一下新根,return。
void del(int x) {
int gg=find(x);
if(cnt[root]>1) {
cnt[root]--;
return;
}
if(!ch[root][0] && !ch[root][1]) {
clear(root);
root=0;
return;
}
if(!ch[root][0]) {
int oldroot=root;
root=ch[root][1];
fa[root]=0;
clear(oldroot);
return;
}
else if(!ch[root][1]) {
int oldroot=root;
root=ch[root][0];
fa[root]=0;
clear(oldroot);
return;
}
int oldroot=root;
splay(pre());
fa[ch[oldroot][1]]=root;
ch[root][1]=ch[oldroot][1];
clear(oldroot);
update(root);
return;
}
最后整合成一道模板题。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 1000005
using namespace std;
int read() {
int x=0,f=1;char ch=getchar();
while(ch<'0' || ch>'9') {if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int fa[MAXN],cnt[MAXN],ch[MAXN][2],size[MAXN],val[MAXN],tot_size,root;
void clear(int x) {
fa[x]=ch[x][0]=ch[x][1]=cnt[x]=size[x]=val[x]=0;
}
int get(int x) {
return ch[fa[x]][1]==x;
}
void update(int x) {
if(x) {
size[x]=cnt[x];
if(ch[x][0]) size[x]+=size[ch[x][0]];
if(ch[x][1]) size[x]+=size[ch[x][1]];
}
}
void rotate(int x) {
int pa=fa[x],papa=fa[pa],which=get(x);
ch[pa][which]=ch[x][!which];fa[ch[x][!which]]=pa;
ch[x][!which]=pa;fa[pa]=x;
fa[x]=papa;if(papa) ch[papa][ch[papa][1]==pa]=x;
update(pa);update(x);
}
void splay(int x) {
for(int f;f=fa[x];rotate(x)) {
if(fa[f]) rotate(get(f)==get(x)?f:x);
}
root=x;
}
void insert(int v) {
if(root==0) {
tot_size++;
ch[tot_size][0]=ch[tot_size][1]=fa[tot_size]=0;
val[tot_size]=v;
cnt[tot_size]=size[tot_size]=1;
root=tot_size;
return;
}
int f=0,now=root;
while(true) {
if(val[now]==v) {
cnt[now]++;
update(now);
update(f);
splay(now);
return;
}
f=now;
now=ch[now][val[now]<v];
if(now==0) {
tot_size++;
ch[tot_size][0]=ch[tot_size][1]=0;
fa[tot_size]=f;
val[tot_size]=v;
cnt[tot_size]=1,size[tot_size]=1;
ch[f][val[f]<v]=tot_size;
update(f);
splay(tot_size);
return;
}
}
}
int find(int x) {
int res=0,now=root;
while(true) {
if(x<val[now]) {
now=ch[now][0];
}
else {
res+=size[ch[now][0]];
if(x==val[now]) {
splay(now);
return res+1;
}
res+=cnt[now];
now=ch[now][1];
}
}
}
int findx(int p) {
int now=root;
while(true) {
if(ch[now][0] && p<=size[ch[now][0]]) {
now=ch[now][0];
}
else {
int temp=size[ch[now][0]]+cnt[now];
if(p<=temp) return val[now];
p-=temp;
now=ch[now][1];
}
}
}
int pre() {
int now=ch[root][0];
while(ch[now][1]) now=ch[now][1];
return now;
}
int next() {
int now=ch[root][1];
while(ch[now][0]) now=ch[now][0];
return now;
}
void del(int x) {
int gg=find(x);
if(cnt[root]>1) {
cnt[root]--;
return;
}
if(!ch[root][0] && !ch[root][1]) {
clear(root);
root=0;
return;
}
if(!ch[root][0]) {
int oldroot=root;
root=ch[root][1];
fa[root]=0;
clear(oldroot);
return;
}
else if(!ch[root][1]) {
int oldroot=root;
root=ch[root][0];
fa[root]=0;
clear(oldroot);
return;
}
int oldroot=root;
splay(pre());
fa[ch[oldroot][1]]=root;
ch[root][1]=ch[oldroot][1];
clear(oldroot);
update(root);
return;
}
int main() {
int n=read();
while(n--) {
int opt=read(),x=read();
if(opt==1) insert(x);
if(opt==2) del(x);
if(opt==3) printf("%d\n",find(x));
if(opt==4) printf("%d\n",findx(x));
if(opt==5) {
insert(x);
printf("%d\n",val[pre()]);
del(x);
}
if(opt==6) {
insert(x);
printf("%d\n",val[next()]);
del(x);
}
}
return 0;
}