K大数查询问题
解题思路
本题给我们 $n$ 个位置,$m$ 个操作。每个位置中可以存储多个元素。
操作一共分两种,一种是在某一段位置中各加入一个元素,一种是查询某一段位置中的第 $k$ 大数。
可以发现本题的两个操作都是在某一段区间中进行操作的,想这类的问题启发我们考虑用树套树来做。
首先,每次都在指定区间内执行操作,因此外层树显然可以用线段树,然后有一个查询第 $k$ 大数的操作,因此可以考虑用内层树用平衡树。
但是本题的添加操作不是单点修改,而是区间修改,如果要对全部位置添加一个数的话,此时的添加操作的时间复杂度就是 $O(n)$,如果想让时间复杂度降回来,就需要用到懒标记,但是这里的懒标记维护的是每个区间需要加上哪些元素,需要维护的是一个 set,显然这很难写。
因此本题用线段树套平衡树并不是很好做,我们需要更换思路。
我们可以不用普通的线段树,而是用一棵权值线段树来作为外层树。线段树是建立在下标上的,而权值线段树顾名思义就是建立权值上的。要想在权值上建立线段树,首先我们就要对所有元素的数值进行离散化。最多有 $50000$ 个操作,因此离散化后最多只有 $50000$ 个权值。
普通线段树是以下标为端点,那么每个节点里面维护的应该是它的数值,而权值线段树是以权值为端点,那么每个节点里面维护的应该是它的下标。两种线段树是刚好相反的。
此时每个权值存储的是对应的所有下标,因此每一段权值区间中存储的就是权值区间中的所有权值对应的所有下标。
而对于每一段权值区间内部,我们再用另外一棵普通线段树来维护权值范围内的所有下标的个数 $cnt$。
因此本题需要用到的是一个线段树套线段树的做法。
然后我们再来考虑两个操作,对于给一段下标加上一个数 $c$,对应到权值线段树中就刚好反过来,变成给一个数值 $c$ 加上一段下标。我们成功将一个区间修改变成了单点修改,时间复杂度只有 $O(logn)$,此时我们应该怎样去修改呢,此时对于每一个节点我们需要进入到内层线段树去进行修改,相当于要在内层线段树下标在 $a \sim b$ 的范围内每一个位置都多了一个数,由于内层线段树维护的是每个下标的个数,因此相当于在 $a \sim b$ 中每个位置都 $+1$,这就是一个区间修改操作,我们可以用一个懒标记 $add$ 来维护。
然后再看查询操作,查询操作是询问下标为 $a \sim b$ 范围内的第 $c$ 大数是多少,这里可以用一个类似二分的思路来求,首先从根节点开始看,先看右子区间,设右子区间中下标为 $a \sim b$ 的数的个数为 $k$,如果 $k \geq c$,说明第 $c$ 大数一定在右半区间中,我们就可以递归到右半区间去找,否则说明第 $c$ 大数一定在左半区间中,我们就递归到左半区间去找。
综上所述,我们就可以用线段树套线段树来完成本题的操作。可以发现对于内层线段树来说,需要完成两个操作,一个是区间加上一个数,一个是区间求和,当线段树需要实现同一类操作时,我们就可以用一个叫做 “标记持久化” 的技巧,这个技巧就是我们不将懒标记往下传,这需要我们确定好每个信息的定义,首先有一个我们会维护一个 cnt 信息,表示每一段下标区间中的节点个数,因为要进行标记持久化,所以我们修改 cnt 的定义为只考虑当前区间以及它下面的所有子区间的懒标记的情况下的节点个数是多少。add 则表示当前区间的两个子区间中的每个数需要 $+$ add。
那么此时如果我们想求某一段区间的总和该怎么求呢,可以分成两部分,第一部分是该区间及以下子区间的和,也就是 cnt,第二部分就是该区间上面的所有区间的标记和,其实就是上面所有区间的 add 的和,记作 add’,而这一部分我们在从根节点递归下来的同时累加以下即可,此时这一部分的节点个数就是 add’ $\times$ length(当前区间的长度)。最后当前区间的节点个数就是两部分加在一起,也就是 cnt + add’ $\times$ length。这样我们就能在不用做 pushdown 也能求出每个区间的节点个数了,代码更简短一些。
另外需要注意,本题是线段树套线段树,内层和外层的线段树都需要用到 $n$ 的长度,因此整个的空间复杂度就是 $O(n^2)$,是开不了那么大的,因此本题还需要用到线段树的动态开点技巧,这样只开需要用到的节点,空间会减少很多。
以上就是本题的整个思路,思维难度不高,但是考察的知识点很多。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 50010, P = N * 17 * 17;
struct node
{
int l, r;
LL sum, add;
}tr[P]; //内层线段树
struct Node
{
int l, r;
int t; //当前区间对应的线段树的根节点下标
}TR[N * 4]; //外层线段树
int idx;
struct Query
{
int op, a, b, c;
}q[N]; //存储所有查询
int n, m;
vector<int> nums; //离散化用数组
int find(int x) //查找 x 离散化后的下标
{
int l = 0, r = nums.size() - 1;
while(l < r)
{
int mid = l + r + 1 >> 1;
if(nums[mid] <= x) l = mid;
else r = mid - 1;
}
return r;
}
int intersection(int a, int b, int c, int d) //求 [a, b] 和 [c, d] 之间的相交节点数(保证相交)
{
return min(b, d) - max(a, c) + 1;
}
void update(int u, int l, int r, int L, int R) //内层下标线段树:在下标 L ~ R 每个位置 + 1
{
tr[u].sum += intersection(l, r, L, R);
if(l >= L && r <= R)
{
tr[u].add++;
return;
}
int mid = l + r >> 1;
if(L <= mid)
{
if(!tr[u].l) tr[u].l = ++idx; //动态开点
update(tr[u].l, l, mid, L, R);
}
if(R > mid)
{
if(!tr[u].r) tr[u].r = ++idx;
update(tr[u].r, mid + 1, r, L, R);
}
}
LL get_sum(int u, int l, int r, int L, int R, LL add) //内层下标线段树:查询 L ~ R 的节点个数
{
if(l >= L && r <= R) return tr[u].sum + add * (r - l + 1);
int mid = l + r >> 1;
LL res = 0;
add += tr[u].add;
if(L <= mid)
{
if(tr[u].l) res += get_sum(tr[u].l, l, mid, L, R, add);
else res += add * intersection(l, mid, L, R); //如果左边节点不存在,说明没有被更新过
}
if(R > mid)
{
if(tr[u].r) res += get_sum(tr[u].r, mid + 1, r, L, R, add);
else res += add * intersection(mid + 1, r, L, R);
}
return res;
}
void build(int u, int l, int r) //外层权值线段树:初始化
{
TR[u] = {l, r, ++idx};
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
void modify(int u, int l, int r, int c) //外层权值线段树:在权值 c 中加入下标 l ~ r
{
update(TR[u].t, 1, n, l, r);
if(TR[u].l == TR[u].r) return;
int mid = TR[u].l + TR[u].r >> 1;
if(c <= mid) modify(u << 1, l, r, c);
else modify(u << 1 | 1, l, r, c);
}
int query(int u, int l, int r, LL c) //外层权值线段树:查询下标 l ~ r 中的第 c 大数
{
if(TR[u].l == TR[u].r) return TR[u].r;
int mid = TR[u].l + TR[u].r >> 1;
LL k = get_sum(TR[u << 1 | 1].t, 1, n, l, r, 0); //查询在右半区间中下标为 l ~ r 的节点个数
if(k >= c) return query(u << 1 | 1, l, r, c);
return query(u << 1, l, r, c - k);
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 0; i < m; i++)
{
scanf("%d%d%d%d", &q[i].op, &q[i].a, &q[i].b, &q[i].c);
if(q[i].op == 1) nums.push_back(q[i].c);
}
//离散化
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
build(1, 0, nums.size() - 1); //建立外层权值线段树
for(int i = 0; i < m; i++)
{
int op = q[i].op, a = q[i].a, b = q[i].b, c = q[i].c;
if(op == 1) modify(1, a, b, find(c)); //将 [a ~ b] 中每个位置加入一个 c
else printf("%d\n", nums[query(1, a, b, c)]); //查询 [a ~ b] 中的第 c 大数
}
return 0;
}
现在代码会被卡最后一个数据
已修~~爆int了,改一下longlong就行了
这个区间求交集是啥意思啊?
在内层线段树中,要快速更新区间信息,前面讲了计算信息的时候需要快速知道区间中的节点数量用来计算,所以这里直接函数来专门计算区间中的节点个数,因为只有当前区间和查询区间的交集才是我们要更新的部分。