树套树
解题思路
本题需要实现 $5$ 种操作,其中有一种单点修改操作,其他都是查询操作。
有两个操作是查找某个区间中 $< x$ 的最大数和 $> x$ 的最小数,也就是区间找前驱、后继。这两个操作加上单点修改操作,这三个操作我们可以用一个线段树套平衡树来实现。
如果只有这三个操作,那么我们甚至可以直接用线段树套 set 就能实现,具体思路可以参考 “平衡树-简单版”,但是本题除了这三个操作还多了两个查询操作。
一个操作是查询某个区间中的某个数的排名,一旦涉及到排名,我们就能用一个简单的 set 来实现了,只能手写一个平衡树来实现查找排名的操作。为了查找排名,我们只需要在平衡树中维护每个子树的大小信息 cnt 即可。
然后线段树会将我们要查询的区间分成若干个子区间,此时我们如何查询这若干个子区间拼凑成的区间中的排名呢,求一个数 $x$ 在区间中的排名其实就是求区间中有多少个数比 $x$ 小,因此我们可以分别求一下这若干个子区间中有多少个数比 $x$ 小,这些数量累加在一起,就是整个区间中比 $x$ 小的数的个数,就能得出 $x$ 的排名。至于怎么求每个子区间中有多少个数比 $x$ 小,由于每个子区间都是一个平衡树,而在平衡树中我们直接递归求就行了。
最后一个操作是查询某个区间中排名第 $k$ 的值,由于线段树将我们要查询的区间分成了若干个子区间,我们需要通过查询子区间里的信息来得到整个区间中排名第 $k$ 的数,可以发现我们并不能通过每个子区间中排名第 $k$ 的数是多少来综合得到整个区间排名第 $k$ 的数。
那么怎么求呢,这里可以用二分来求,由于对于一个数和它的排名,很显然数值越大排名越大,数值越小排名越小,这是满足单调性的,因此我们可以二分一下答案 $x$,然后我们用操作 $1$ 求一下当前 $x$ 的排名是多少,如果 $\leq k$ 说明答案一定 $\geq x$。如果 $> k$ 说明答案一定 $< x$。这样我们就能通过二分求出排名为 $k$ 的数。
通过以上思路就能完成本题的所有操作,本题时间复杂度的瓶颈是操作 $2$,需要 $O(log^3n)$ 的复杂度,因此整个的时间复杂度应该是 $O(mlog^3n)$
C++ 代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1500010, INF = 0x3f3f3f3f;
struct node
{
int s[2], p, v; //子节点、父节点、数值
int cnt; //cnt 表示当前节点的子树中的节点数量
void init(int _v, int _p) //初始化节点
{
v = _v, p = _p;
cnt = 1;
}
}tr[N]; //平衡树(内层树)
int idx;
struct Node
{
int l, r;
int t; //t 表示区间 [l ~ r] 对应的平衡树的根节点
}TR[N]; //线段树(外层树)
int n, m;
int a[N];
void pushup(int x) //平衡树:用子节点信息更新当前节点信息
{
tr[x].cnt = tr[tr[x].s[0]].cnt + tr[tr[x].s[1]].cnt + 1;
}
void rotate(int x) //平衡树:左旋、右旋
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int &root, int x, int k) //平衡树:将 x 旋到 k 的下面
{
while(tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;
if(z != k)
{
if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(!k) root = x;
}
void insert(int &root, int v) //平衡树:插入数值 v
{
int u = root, p = 0;
while(u) p = u, u = tr[u].s[v > tr[u].v];
u = ++idx;
if(p) tr[p].s[v > tr[p].v] = u;
tr[u].init(v, p);
splay(root, u, 0);
}
int get_x(int root, int v) //平衡树:找出 < v 的数的个数
{
int u = root;
int res = 0;
while(u)
{
if(tr[u].v < v) res += tr[tr[u].s[0]].cnt + 1, u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
void update(int &root, int p, int x) //平衡树:将 p 修改为 x
{
//找出 p 所在的位置
int u = root;
while(u)
{
if(tr[u].v == p) break;
if(tr[u].v < p) u = tr[u].s[1];
else u = tr[u].s[0];
}
splay(root, u, 0); //将 p 转到根节点
//找出 p 的前驱和后继
int l = tr[u].s[0], r = tr[u].s[1];
while(tr[l].s[1]) l = tr[l].s[1];
while(tr[r].s[0]) r = tr[r].s[0];
splay(root, l, 0), splay(root, r, l); //将 p 的前驱和后继旋上来
tr[r].s[0] = 0; //将 p 删去
pushup(r), pushup(l); //更新信息
insert(root, x); //将 x 插入进来
}
int get_pre(int root, int v) //平衡树:找出 < v 的最大值
{
int u = root, res = -INF;
while(u)
{
if(tr[u].v < v) res = max(res, tr[u].v), u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
int get_next(int root, int v) //平衡树:找出 > v 的最小值
{
int u = root, res = INF;
while(u)
{
if(tr[u].v > v) res = min(res, tr[u].v), u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
void build(int u, int l, int r) //线段树:初始化
{
TR[u] = {l, r};
insert(TR[u].t, -INF), insert(TR[u].t, INF); //给当前区间对应的平衡树插入哨兵
for(int i = l; i <= r; i++) insert(TR[u].t, a[i]); //将当前区间内的所有数插入对应的平衡树中
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
int query(int u, int l, int r, int x) //线段树:查询 [l ~ r] 中 < x 的数的个数
{
if(TR[u].l >= l && TR[u].r <= r) return get_x(TR[u].t, x) - 1; //返回 < x 的数的个数(减去左哨兵)
int mid = TR[u].l + TR[u].r >> 1;
int res = 0;
if(l <= mid) res += query(u << 1, l, r, x);
if(r > mid) res += query(u << 1 | 1, l, r, x);
return res;
}
void modify(int u, int p, int x) //线段树:修改 w[p] 为 x
{
update(TR[u].t, a[p], x); //将当前区间对应的平衡树中的 w[p] 删去,插入 x
if(TR[u].l == TR[u].r) return;
int mid = TR[u].l + TR[u].r >> 1;
if(p <= mid) modify(u << 1, p, x);
else modify(u << 1 | 1, p, x);
}
int query_pre(int u, int l, int r, int x) //线段树:查询 [l ~ r] 中 < x 的最大数
{
if(TR[u].l >= l && TR[u].r <= r) return get_pre(TR[u].t, x); //在当前区间对应的平衡树中找出 < x 的最大数
int mid = TR[u].l + TR[u].r >> 1;
int res = -INF;
if(l <= mid) res = max(res, query_pre(u << 1, l, r, x));
if(r > mid) res = max(res, query_pre(u << 1 | 1, l, r, x));
return res;
}
int query_next(int u, int l, int r, int x) //线段树:查询 [l ~ r] 中 > x 的最小数
{
if(TR[u].l >= l && TR[u].r <= r) return get_next(TR[u].t, x); //在当前区间对应的平衡树中找出 > x 的最小数
int mid = TR[u].l + TR[u].r >> 1;
int res = INF;
if(l <= mid) res = min(res, query_next(u << 1, l, r, x));
if(r > mid) res = min(res, query_next(u << 1 | 1, l, r, x));
return res;
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
build(1, 1, n);
while(m--)
{
int op, l, r, pos, x;
scanf("%d", &op);
if(op == 1) //查询 [l ~ r] 中 x 的排名
{
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", query(1, l, r, x) + 1);
}
else if(op == 2) //查询 [l ~ r] 中排名为 x 的数
{
scanf("%d%d%d", &l, &r, &x);
//二分
int L = 0, R = 1e8;
while(L < R)
{
int mid = L + R + 1 >> 1;
if(query(1, l, r, mid) + 1 <= x) L = mid;
else R = mid - 1;
}
printf("%d\n", R);
}
else if(op == 3) //将 pos 位置上的数改为 x
{
scanf("%d%d", &pos, &x);
modify(1, pos, x);
a[pos] = x;
}
else if(op == 4) //查询 [l ~ r] 中 < x 的最大数
{
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", query_pre(1, l, r, x));
}
else //查询 [l ~ r] 中 > x 的最小数
{
scanf("%d%d%d", &l, &r, &x);
printf("%d\n", query_next(1, l, r, x));
}
}
return 0;
}