莫欺少年穷,修魔之旅在这开始—>算法提高课题解
思路:
1. As everyone knows,Treap(平衡树) = BST(二叉搜索树) + heap(大根堆)
2. 本题难度不大,就是函数有点多,稍微耐心点即可做出
3. 右旋(zig)和左旋(zag)并不影响中序遍历
4. key 维护二叉搜索树,val 维护大根堆
5. 本题还有两个哨兵 -INF 和 INF,故 get_rank_by_key() 和 get_key_by_rank() 时需要留心
#include<bits/stdc++.h>
using namespace std;
const int N = 100010, INF = 0x3f3f3f3f;
int n;
int root,idx;
struct Node
{
int l,r;
int key,val;
int cnt,size;
}tr[N];
//维护每个节点有多少个数
void pushup(int p)
{
tr[p].size=tr[tr[p].l].size+tr[tr[p].r].size+tr[p].cnt;
}
//建立新节点
int get_node(int key)
{
tr[++idx].key=key;
tr[idx].val=rand();
tr[idx].cnt=tr[idx].size=1;
return idx;
}
//右旋
void zig(int &p)
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[q].r),pushup(q);
}
//左旋
void zag(int &p)
{
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
pushup(tr[q].l),pushup(q);
}
//初始化平衡树
void build()
{
root=get_node(-INF);
tr[root].r=get_node(INF);
pushup(root);
if(tr[1].val<tr[2].val) zag(root);
}
//插入操作
void insert(int &p,int key)
{
if(!p) p=get_node(key);
else if(tr[p].key==key) tr[p].cnt++;
else if(tr[p].key>key)
{
insert(tr[p].l,key);
if(tr[tr[p].l].val>tr[p].val) zig(p);
}
else
{
insert(tr[p].r,key);
if(tr[tr[p].r].val>tr[p].val) zag(p);
}
//回溯更新
pushup(p);
}
//删除操作
void remove(int &p,int key)
{
if(!p) return;
if(tr[p].key==key)
{
if(tr[p].cnt>1) tr[p].cnt--;
else if(tr[p].l||tr[p].r)
{
if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val)
{
zig(p);
remove(tr[p].r,key);
}
else
{
zag(p);
remove(tr[p].l,key);
}
}
else p=0;
}
else if(tr[p].key>key) remove(tr[p].l,key);
else remove(tr[p].r,key);
//回溯更新
pushup(p);
}
//通过数值找排名
int get_rank_by_key(int p,int key)
{
if(tr[p].key==key) return tr[tr[p].l].size+1;
else if(tr[p].key>key) return get_rank_by_key(tr[p].l,key);
else return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
}
//通过排名找数值
int get_key_by_rank(int p,int rank)
{
if(tr[tr[p].l].size>=rank) return get_key_by_rank(tr[p].l,rank);
else if(tr[tr[p].l].size+tr[p].cnt>=rank) return tr[p].key;
else return get_key_by_rank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
}
//找前驱(小于 key 的最大值)
int get_prev(int p,int key)
{
if(!p) return -INF;
if(tr[p].key>=key) return get_prev(tr[p].l,key);
return max(tr[p].key,get_prev(tr[p].r,key));
}
//找后继(大于 key 的最小值)
int get_next(int p,int key)
{
if(!p) return INF;
if(tr[p].key<=key) return get_next(tr[p].r,key);
return min(tr[p].key,get_next(tr[p].l,key));
}
int main()
{
build();
cin>>n;
while(n--)
{
int t,x;
cin>>t>>x;
//插入操作
if(t==1) insert(root,x);
//删除操作
else if(t==2) remove(root,x);
//通过数值找排名
else if(t==3) cout<<get_rank_by_key(root,x)-1<<endl;
//通过排名找数值
else if(t==4) cout<<get_key_by_rank(root,x+1)<<endl;
//找前驱(小于 key 的最大值)
else if(t==5) cout<<get_prev(root,x)<<endl;
//找后继(大于 key 的最小值)
else cout<<get_next(root,x)<<endl;
}
return 0;
}