平衡树(treap)
平均O(logn)
例题
您需要写一种数据结构,来维护一些数,其中需要提供以下操作:
1.插入数值 x。
2.删除数值 x(若有多个相同的数,应只删除一个)。
3.查询数值 x 的排名(若有多个相同的数,应输出最小的排名)。
4.查询排名为 x 的数值。
5.求数值 x 的前驱(前驱定义为小于 x 的最大的数)。
6.求数值 x 的后继(后继定义为大于 x 的最小的数)。
int root, idx;
struct Node
{
int l, r;
int key, val;
int cnt, size;
} tr[N];
void pushup(int p)
{
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
int get_node(int key)
{
tr[++idx].key = key; //节点存储值
tr[idx].val = rand(); //随机以保持平衡性
tr[idx].cnt = tr[idx].size = 1;
return idx;
}
void zig(int &p) //右旋
{
int q = tr[p].l;
tr[p].l = tr[q].r, tr[q].r = p, p = q;
pushup(tr[p].r), pushup(p);
}
void zag(int &p) //左旋
{
int q = tr[p].r;
tr[p].r = tr[q].l, tr[q].l = p, p = q;
pushup(tr[p].l), pushup(p);
}
void build()
{
get_node(-INF), get_node(INF); //哨兵
root = 1, tr[1].r = 2;
pushup(root);
if (tr[1].val < tr[2].val)
{
zag(root);
}
}
void insert(int &p, int key)
{
if (!p)
{
p = get_node(key);
}
else if (tr[p].key == key)
{
tr[p].cnt++;
}
else if (key < tr[p].key)
{
insert(tr[p].l, key);
if (tr[tr[p].l].val > tr[p].val)
{
zig(p);
}
}
else if (tr[p].key < key)
{
insert(tr[p].r, key);
if (tr[tr[p].r].val < tr[p].val)
{
zag(p);
}
}
pushup(p);
}
void remove(int &p, int key)
{
if (!p)
{
return;
}
if (tr[p].key == key)
{
if (tr[p].cnt > 1)
{
tr[p].cnt--;
}
else if (tr[p].l || tr[p].r)
{
if (!tr[p].r || tr[tr[p].l].val > tr[tr[p].r].val)
{
zig(p);
remove(tr[p].r, key);
}
else
{
zag(p);
remove(tr[p].l, key);
}
}
else
{
p = 0;
}
}
else if (key < tr[p].key)
{
remove(tr[p].l, key);
}
else if (tr[p].key < key)
{
remove(tr[p].r, key);
}
pushup(p);
}
int get_rank_by_key(int p, int key) //通过数值找最小排名
{
if (!p)
{
return 0;
}
if (tr[p].key == key)
{
return tr[tr[p].l].size + 1;
}
else if (key < tr[p].key)
{
return get_rank_by_key(tr[p].l, key);
}
else if (tr[p].key < key)
{
return tr[tr[p].l].size + tr[p].cnt + get_rank_by_key(tr[p].r, key);
}
}
int get_key_by_rank(int p, int rank) //通过排名找数值
{
if (!p)
{
return INF;
}
if (tr[tr[p].l].size >= rank)
{
return get_key_by_rank(tr[p].l, rank);
}
if (tr[tr[p].l].size + tr[p].cnt >= rank)
{
return tr[p].key;
}
return get_key_by_rank(tr[p].r, rank - tr[tr[p].l].size - tr[p].cnt);
}
int get_pre(int p, int key) //找到严格小于key的最大数
{
if (!p)
{
return -INF;
}
if (tr[p].key >= key)
{
return get_pre(tr[p].l, key);
}
return max(tr[p].key, get_pre(tr[p].r, key));
}
int get_next(int p, int key) //找到严格大于key的最小数
{
if (!p)
{
return INF;
}
if (tr[p].key <= key)
{
return get_next(tr[p].r, key);
}
return min(tr[p].key, get_next(tr[p].l, key));
}
int main()
{
std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
build();
cin >> n;
while (n--)
{
int op, x;
cin >> op >> x;
if (op == 1)
{
insert(root, x);
}
else if (op == 2)
{
remove(root, x);
}
else if (op == 3)
{
cout << get_rank_by_key(root, x) - 1 << endl; //左侧多了一个-INF
}
else if (op == 4)
{
cout << get_key_by_rank(root, x + 1) << endl; //左侧多了一个-INF
}
else if (op == 5)
{
cout << get_pre(root, x) << endl;
}
else if (op == 6)
{
cout << get_next(root, x) << endl;
}
}
return 0;
}