题目描述
著名的树套树板子题
下标线段树套值域平衡树
时间复杂度 $O(n\log^3 n)$（瓶颈是操作 2：对值域二分 $O(\log V)$ 次，每次区间排名查询为 $O(\log^2 n)$）
C++ 代码
#include<cstdio>
#include<cstdlib>
#include<algorithm>
using namespace std;
// N   : capacity bound for positions (w[], tr[] and the treap node pool are sized from it).
// INF : sentinel returned by get_prev / get_next (and propagated by query_prev /
//       query_next) when no predecessor / successor exists; main() prints it verbatim.
//       The standard statement of this template problem (Luogu P3380 / AcWing 2476)
//       requires -2147483647 / 2147483647 in that case, so INF must be INT_MAX rather
//       than the conventional 0x3f3f3f3f. INF is only negated and compared, never added
//       to, so INT_MAX cannot overflow here.
const int N = 50010, INF = 2147483647;
struct FHQ{
int l, r;
int val, key, sz;
}fhq[N * 40];
int tot;
int get_node(int val)
{
int u = ++ tot;
fhq[u].l = fhq[u].r = 0;
fhq[u].val = val;
fhq[u].key = rand();
fhq[u].sz = 1;
return u;
}
void pushup(int u)
{
fhq[u].sz = fhq[fhq[u].l].sz + fhq[fhq[u].r].sz + 1;
}
void split(int u, int val, int &x, int &y)
{
if(!u) x = y = 0;
else
{
if(fhq[u].val <= val)
{
x = u;
split(fhq[u].r, val, fhq[u].r, y);
}
else
{
y = u;
split(fhq[u].l, val, x, fhq[u].l);
}
pushup(u);
}
}
int merge(int x, int y)
{
if(!x or !y) return x | y;
if(fhq[x].key > fhq[y].key)
{
fhq[x].r = merge(fhq[x].r, y);
pushup(x);
return x;
}
else
{
fhq[y].l = merge(x, fhq[y].l);
pushup(y);
return y;
}
}
int x, y, z;
void insert(int &root, int val)
{
split(root, val, x, y);
root = merge(merge(x, get_node(val)), y);
}
void dele(int &root, int val)
{
split(root, val - 1, x, y);
split(y, val, y, z);
y = merge(fhq[y].l, fhq[y].r);
root = merge(merge(x, y), z);
}
int get_rank(int &root, int val)
{
split(root, val - 1, x, y);
int res = fhq[x].sz;
root = merge(x, y);
return res;
}
int get_val(int &root, int k)
{
int u = root;
while(u)
{
if(fhq[fhq[u].l].sz + 1 == k) break;
if(fhq[fhq[u].l].sz >= k) u = fhq[u].l;
else
{
k -= fhq[fhq[u].l].sz + 1;
u = fhq[u].r;
}
}
return fhq[u].val;
}
int get_prev(int &root, int val)
{
split(root, val - 1, x, y);
int u = x;
while(fhq[u].r) u = fhq[u].r;
int res = fhq[u].val;
if(!x) res = -INF;
root = merge(x, y);
return res;
}
int get_next(int &root, int val)
{
split(root, val, x, y);
int u = y;
while(fhq[u].l) u = fhq[u].l;
int res = fhq[u].val;
if(!y) res = INF;
root = merge(x, y);
return res;
}
int n, m;
int w[N];
struct Node{
int l, r;
int root;
}tr[N << 2];
void build(int u, int l, int r)
{
tr[u] = {l, r};
for(int i = l; i <= r; i ++) insert(tr[u].root, w[i]);
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
int query_rank(int u, int l, int r, int val)
{
if(tr[u].l >= l and tr[u].r <= r) return get_rank(tr[u].root, val);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if(l <= mid) res = query_rank(u << 1, l, r, val);
if(r > mid) res += query_rank(u << 1 | 1, l, r, val);
return res;
}
int query_val(int x, int y, int k)
{
int l = 0, r = 1e8;
while(l < r)
{
int mid = l + r + 1 >> 1;
if(query_rank(1, x, y, mid) < k) l = mid;
else r = mid - 1;
}
return l;
}
void modify(int u, int x, int val)
{
dele(tr[u].root, w[x]);
insert(tr[u].root, val);
if(tr[u].l == tr[u].r) return;
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid) modify(u << 1, x, val);
else modify(u << 1 | 1, x, val);
}
int query_prev(int u, int l, int r, int val)
{
if(tr[u].l >= l and tr[u].r <= r) return get_prev(tr[u].root, val);
int mid = tr[u].l + tr[u].r >> 1, res = -INF;
if(l <= mid) res = query_prev(u << 1, l, r, val);
if(r > mid) res = max(res, query_prev(u << 1 | 1, l, r, val));
return res;
}
int query_next(int u, int l, int r, int val)
{
if(tr[u].l >= l and tr[u].r <= r) return get_next(tr[u].root, val);
int mid = tr[u].l + tr[u].r >> 1, res = INF;
if(l <= mid) res = query_next(u << 1, l, r, val);
if(r > mid) res = min(res, query_next(u << 1 | 1, l, r, val));
return res;
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++) scanf("%d", &w[i]);
build(1, 1, n);
while(m --)
{
int op, l, r, x, k, val;
scanf("%d", &op);
if(op == 1)
{
scanf("%d%d%d", &l, &r, &val);
printf("%d\n", query_rank(1, l, r, val) + 1);
}
else if(op == 2)
{
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", query_val(l, r, k));
}
else if(op == 3)
{
scanf("%d%d", &x, &val);
modify(1, x, val);
w[x] = val;
}
else if(op == 4)
{
scanf("%d%d%d", &l, &r, &val);
printf("%d\n", query_prev(1, l, r, val));
}
else
{
scanf("%d%d%d", &l, &r, &val);
printf("%d\n", query_next(1, l, r, val));
}
}
return 0;
}