splay(平衡树)
今天学个splay吧(就用y总的普通平衡树试一下)
向秦大佬学习的一天,终于搞懂了splay,太激动了!!!
首先得感谢秦大佬 @秦淮岸灯火阑珊,写的巨好,还有yyb大佬–虽然不认识,但一定很厉害啊
还有我们的y总 @yxc,让我打下了treap的基础!!!
题目链接: ACwing253.普通平衡树
接下来是正片了
1. update更新操作(用来更新树的大小)
void update(int x)
{
t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + t[x].cnt;
}
当前节点树的大小等于左子树的大小加上右子树的大小加上当前节点的个数
$$t[u.leftson].size + t[u.rightson].size + t[u].cnt$$
2. rotate旋转操作(这个应该是最重要的)
注:鼠标画图可能不是很好,希望大家多多担待哈 QWQ !
有位学长说的好啊:“这个操作就是像拉着X – Y这条链子拖下来or拽上去”,而且要保证二叉树的中序遍历不发生改变;
像图中一样:
X小于Z应该作为Z的左儿子, 而Y大于X,所以Y应该作为X的右儿子;
void rotate(int x)
{
int y = t[x].ff;//y为x的父亲
int z = t[y].ff;//z为x的祖父
int k = (t[y].ch[1] == x);//k是判断x是y的左还是右儿子
t[z].ch[t[z].ch[1] == y] = x;//将x接到z后面(x与之前y的位置相同,如图:都比z小)
t[x].ff = z;
t[y].ch[k] = t[x].ch[k ^ 1];//将y接到与之前x相反的位置,因为y > x;
t[t[x].ch[k ^ 1]].ff = y;
t[x].ch[k ^ 1] = y;
t[y].ff = x;
update(y), update(x); //更新操作
}
3.splay操作
因为会出现一条链的样子,这会使$O(logN)$的操作变成$O(N)$
splay会使它变得更”平衡“
void splay(int x, int goal)//将x旋转为goal的儿子
{
while (t[x].ff != goal)
{
int y = t[x].ff, z = t[y].ff;
if (z != goal)
(t[y].ch[0] == x) ^ (t[z].ch[0] == y) ? rotate(x) : rotate(y);
//如果是一条链的话就先旋y节点,否则就只旋x节点
rotate(x);
}
if (goal == 0) root = x;
}
4.还有find操作,大概就是将要找的数先旋到根节点
这个查询其实有点类似于二分查询的样子
void find(int x)
{
int u = root;
if (!u) return;//树是空的
while (t[u].ch[x > t[u].val] && x != t[u].val)
u = t[u].ch[x > t[u].val];
splay(u, 0);
}
5.splay支持的操作
(1). 插入一个数
void insert(int x)
{
int u = root, ff = 0;//从根节点开始
while (u && x != t[u].val)
{
ff = u;//一直往下找
u = t[u].ch[x > t[u].val];
}
if (u) t[u].cnt ++;//如果找到了这个数,就让$cnt++$;
else
{
u = ++ tot;//没找到就新开一个节点
if (ff) t[ff].ch[x > t[ff].val] = u;
t[u].ch[0] = t[u].ch[1] = 0;
//更新各个变量
t[tot].val = x;//更新数值
t[tot].ff = ff;//更新父节点
t[tot].size = 1;//更新子树大小
t[tot].cnt = 1;//更新数量
}
splay(u, 0);//记得一定要splay啊
}
(2). 查询前驱后继操作(pre_next)
首先用 $find$函数 将要做前驱后继操作的数找出来
然后找前驱后继;
因为前驱肯定小于$x$,所以从左子树找最大值,也就是小于x的最大值
后继同理;
int pre_next(int x, int f) //这里将前驱和后继函数合并, f == 0(前驱), f == 1(后继)
{
find(x);
int u = root;
if (t[u].val > x && f) return u; //后继
if (t[u].val < x && !f) return u; //前驱
u = t[u].ch[f];
while (t[u].ch[f ^ 1]) u = t[u].ch[f ^ 1];//反着找,小于x的max或者大于x的min
return u;
}
(3).删除操作(Delete) D要大写哦,记得是好像有重复关键字,remove也行,无所谓啦!
核心思想:
用前驱后继删除,这就是为啥先说前驱后继的原因啦;
先找到prev(前驱),将它splay到root(根节点),再找到next(后继),将它splay到prev后;
因为比prev大的数是next和当前节点在右子树,比next小的是prev和当前节点,在根节点和左子树,
以及右子树的左儿子;
故只需将根节点右子树的左儿子删掉即可 (是不是很巧妙, 快夸秦大佬和yyb大佬–虽然我不认识,但肯定很厉害)
void Delete(int x)
{
int prev = prev_next(x, 0);
int next = prev_next(x, 1);
splay(prev, 0), splay(next, prev);
int del = t[next].ch[0];
if (t[del].cnt > 1)
{
t[del].cnt --;
splay(del, 0);
}
else t[next].ch[0] = 0;
}
(4). 排名第k小的数
这个就比较简单了
看代码吧:
int kth(int x) //第k小数
{
int u = root;
if (t[u].size < x)//如果没这么多数,直接return;
return 0;
while (1)
{
int y = t[u].ch[0];
if (x > t[y].size + t[u].cnt)//如果排名大于左儿子的大小加上当前结点的数量,则一定在右子树
{
x -= t[y].size + t[u].cnt;//数量减少
u = t[u].ch[1];
}
else //否则就在左子树
{
if (t[y].size >= x)//节点足够那么就在左子树
u = y;
else return t[u].val;//否则就是当前节点
}
}
}
终于整完了,接下来放个完整的代码:
//splay
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200010;
struct splay_tree
{
int ff, cnt;
int ch[2], val, size;
}t[N];
int root, tot;
inline void update(int x)
{
t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + t[x].cnt;
}
inline void rotate(int x)
{
int y = t[x].ff;
int z = t[y].ff;
int k = (t[y].ch[1] == x);
t[z].ch[t[z].ch[1] == y] = x;
t[x].ff = z;
t[y].ch[k] = t[x].ch[k ^ 1];
t[t[x].ch[k ^ 1]].ff = y;
t[x].ch[k ^ 1] = y;
t[y].ff = x;
update(y), update(x);
}
inline void splay(int x, int goal)//将x旋转为goal的儿子
{
while (t[x].ff != goal)
{
int y = t[x].ff, z = t[y].ff;
if (z != goal)
(t[y].ch[0] == x) ^ (t[z].ch[0] == y) ? rotate(x) : rotate(y);
rotate(x);
}
if (goal == 0) root = x;
}
inline void find(int x)
{
int u = root;
if (!u) return;//树是空的
while (t[u].ch[x > t[u].val] && x != t[u].val)
u = t[u].ch[x > t[u].val];
splay(u, 0);
}
inline void insert(int x)
{
int u = root, ff = 0;
while (u && x != t[u].val)
{
ff = u;
u = t[u].ch[x > t[u].val];
}
if (u) t[u].cnt ++;
else
{
u = ++ tot;
if (ff) t[ff].ch[x > t[ff].val] = u;
t[u].ch[0] = t[u].ch[1] = 0;
t[tot].val = x;
t[tot].ff = ff;
t[tot].size = 1;
t[tot].cnt = 1;
}
splay(u, 0);
}
inline int prev_next(int x, int f) //这里将前驱和后继函数合并, f == 0(前驱), f == 1(后继)
{
find(x);
int u = root;
if (t[u].val > x && f) return u; //后继
if (t[u].val < x && !f) return u; //前驱
u = t[u].ch[f];
while (t[u].ch[f ^ 1]) u = t[u].ch[f ^ 1];
return u;
}
inline void Delete(int x)
{
int prev = prev_next(x, 0);
int next = prev_next(x, 1);
splay(prev, 0), splay(next, prev);
int del = t[next].ch[0];
if (t[del].cnt > 1)
{
t[del].cnt --;
splay(del, 0);
}
else t[next].ch[0] = 0;
}
inline int kth(int x) //第k小数
{
int u = root;
if (t[u].size < x)
return 0;
while (1)
{
int y = t[u].ch[0];
if (x > t[y].size + t[u].cnt)
{
x -= t[y].size + t[u].cnt;
u = t[u].ch[1];
}
else
{
if (t[y].size >= x)
u = y;
else return t[u].val;
}
}
}
int main()
{
int n;
scanf("%d", &n);
insert(1e9), insert(-1e9);
while (n -- )
{
int op, x;
scanf("%d%d", &op, &x);
if (op == 1) insert(x);
else if (op == 2) Delete(x);
else if (op == 3)
{
find(x);
printf("%d\n", t[t[root].ch[0]].size);
}
else if (op == 4) printf("%d\n", kth(x + 1));
else if (op == 5) printf("%d\n", t[prev_next(x, 0)].val);
else printf("%d\n", t[prev_next(x, 1)].val);
}
return 0;
}
删除操作中的else 后面是不是少了一句,应该splay一下,要不然当前节点和根节点的信息就不对了
orzorz
Orz
orzorzorz
巨