相较线段树套平衡树,树状数组套权值线段树时间复杂度在查询上少了一个log,相应付出的代价就是空间多了一个log。
感觉称作套权值线段树更好些,虽然采取了一定可持久化的思路、操作,与主席树很像,但是毕竟仍有区别————由于我们是对树状数组的每个下标建立一个权值线段树,所以这里的每个权值线段树之间是互相独立的,假设对应的是树状数组的第 i 个下标位置,那么该权值线段树管理的是初始数组 a[] 的 (i−lowbit(i),i] 的区间值域信息 ;而主席树的每个权值线段树之间是相互依赖的,其第 i 颗权值线段树管理的是初始数组 [1,i]的区间值域信息。
树状数组套主席树,带pushup版 O(mlog2n)
- 部分注释(主要是pre,suf里的)是由于自己实现时因编写方式不妥致使出现了一些bug,调试了很长时间,后面想清楚了,为便于日后复习回忆才添加的,语言组织不太流畅,可以跳过不做阅读或参考
- 不带pushup的代码其里面的函数基本都是采取迭代实现
#include <bits/stdc++.h>
#define ls(q) tr[q].ls
#define rs(q) tr[q].rs
#define pc(c) putchar(c)
#define rep(a,b,c) for (int (a) = (b) ; (a) < (c) ; ++(a))
using namespace std;
using ll = long long ;
using pii = pair<int,int> ;
const int maxn = 5e4 + 10,inf = 0x7fffffff ;
struct pb{
int type,l,r,k;
}q[maxn]; // 存询问
int vec[maxn << 2],lt[maxn],llen,rt[maxn],rlen;
// 分别是查位置的数组,存要减去的 [1,l - 1]的数组 lt ,lt数组长度 llen ,存加上的 [1,r]的数组rt,rt数组的长度
int a[maxn],root[maxn],idx,len,n,m ;
// n、m一定要开全局,开在main里调了好久
// 分别为 初始数组,存根数组,动态开点的idx变量,vec的长度,初始数组长度n,询问个数m
struct node {
int ls,rs,sz = 0; // 左子下标ls,右子下标rs,当前子树节点个数sz
}tr[maxn << 8]; // 需要 nlognlogn 的空间
inline int get(int x){ // 查询x的位置,离散化了
return lower_bound(vec + 1,vec + len + 1,x) - vec;
}
inline int lowbit(int x){
return x & -x ;
}
inline void pushup(int q){ //更新节点信息
tr[q].sz = tr[ls(q)].sz + tr[rs(q)].sz ;
}
inline void refresh(int l,int r){ //更新 lt 、rt数组
llen = rlen = 0;
while (l) lt[++ llen] = root[l],l -= lowbit(l);
while (r) rt[++ rlen] = root[r],r -= lowbit(r);
}
inline void modify(int &q,int l,int r,int pos,int val){ //进行单点修改,权值线段树每个点对应的是值域而非初始数组区间信息
if (!q) q = ++idx ;
if (l ^ r){
int mid = l + r >> 1 ;
if ( pos <= mid ) modify(ls(q),l,mid,pos,val);
else modify(rs(q),mid + 1,r,pos,val);
pushup(q);
}else
tr[q].sz += val ; // 因为有pushup所以只在最后一个地方加val
}
inline void add(int pos,int val){ //往树状数组里面加值,树状数组每个点对应的是初始数组区间而非值域
for (int i = pos ; i <= n ; i += lowbit(i)) modify(root[i],1,len,a[pos],val);
}
inline int get_rk(int l,int r,int k,int res = 1,int mid = 0){ // 获取一个数的排名
refresh(l - 1,r); //先更新lt和rt数组
l = 1,r = len; //采用迭代方式,所以这里重复利用 l,r
while (l ^ r){
mid = l + r >> 1;
if (k <= mid) { // 表示要查的值小于等于值域中间的值,进入左半值域
rep(i,1,rlen + 1) rt[i] = ls(rt[i]);
rep(i,1,llen + 1) lt[i] = ls(lt[i]);
r = l + r >> 1;
continue ;
}
//表示要查的值在右侧值域,将左侧值域的sz先加入答案中
rep(i,1,rlen + 1) res += tr[ls(rt[i])].sz,rt[i] = rs(rt[i]);
rep(i,1,llen + 1) res -= tr[ls(lt[i])].sz,lt[i] = rs(lt[i]);
l = mid + 1;
} // 可以发现,由于我们只会在步入右区间的时候加值,所以我们只会将小于k的信息加入答案中,对于大于等于k的信息我们不会将其加入答案中
return res;
}
inline int get_k(int l,int r,int k,int sum = 0){ // 查询排名为k的数,值是多少
refresh(l - 1,r); // 更新lt、rt数组
l = 1,r = len; //采用迭代方式,所以这里重复利用 l,r
while (l ^ r) {
sum = 0 ;
rep(i,1,rlen + 1) sum += tr[ls(rt[i])].sz ; //+[1,r]的信息
rep(i,1,llen + 1) sum -= tr[ls(lt[i])].sz ; // - [1,l - 1]的信息
if (sum >= k) { //如果左子树的点的个数 >= k,说明答案在左子树
rep(i,1,rlen + 1) rt[i] = ls(rt[i]);
rep(i,1,llen + 1) lt[i] = ls(lt[i]);
r = l + r >> 1;
continue ;
}
// 否则在右子树
rep(i,1,rlen + 1) rt[i] = rs(rt[i]);
rep(i,1,llen + 1) lt[i] = rs(lt[i]);
k -= sum ; // 减去左子树的节点个数
l = (l + r >> 1) + 1;
}
return vec[l]; // 返回离散化后的下标对应的数值
}
inline int get_pre(int l,int r,int x){ // 获取前驱节点
int rk = get_rk(l,r,x) - 1 ;
// 先得到小于x(注意这里的x是离散化后的下标)的[l,r]区间内数的个数;这里不能把 -1 放在括号内,
// 如果放在括号内,就是获取小于x - 1的数的个数,然后在进行rk计算的时候,由于get_rk返回值至少为1,
// 所以还要进行一系列操作来处理,嘎嘎麻烦,不如直接-1计算出小于x的数的个数
return rk ? get_k(l,r,rk) : -inf ; // 如果没有小于x的数,就说明x是最小的了
}
inline int get_suf(int l,int r,int x){ // 获取后继节点
int rk = get_rk(l,r,x + 1) ;
// 获取小于x + 1(注意这里的x是离散化后的下标)的[l,r]区间内数的个数,如果返回值为r - l + 1 + 1,
// 就说明x + 1,在比区间中所有数都大,也就是[l,r]区间内不存在比x更大的,返回inf ;
// 否则 通过rk查询x + 1的值
// 不把 + 1 放到外面主要是 x 已经是离散化后的下标了,如果把 + 1 放到外面可能需要另加一系列操作才能
// 实现相同结果,不如我们把 + 1放在里面,统一好了操作来的方便
return rk == r - l + 2 ? inf : get_k(l,r,rk) ;
}
int main(){
ios::sync_with_stdio(false);
cin >> n >> m ;
vec[len] = -inf ; // 加入头哨兵
rep(i,1,n + 1) cin >> a[i],vec[++len] = a[i],root[i] = ++idx;
rep(i,1,m + 1) { //读入信息
cin >> q[i].type ;
if (q[i].type == 3)
cin >> q[i].l >> q[i].k;
else
cin >> q[i].l >> q[i].r >> q[i].k ;
if (q[i].type ^ 2) //如果不为2,说明k为可能出现在数组a中,将k加入vec中
vec[++len] = q[i].k;
}
vec[++ len] = inf ; // 加入结尾哨兵
sort(vec,vec + len);
len = unique(vec,vec + len + 1) - vec - 1; // 离散化
rep(i,1,n + 1) a[i] = get(a[i]),add(i,1);
// 将a数组节点信息插入树状数组中,注意这里是先转成离散化后的下标再插入,
// 这样操作的主要目的是 统一后面的一系列操作,简化代码
rep(i,1,m + 1){
if (q[i].type == 1) // 操作1,注意 k 要进行离散化
cout << get_rk(q[i].l,q[i].r,get(q[i].k)) << '\n';
else if (q[i].type == 2) //操作2,这里的k不用离散化
cout << get_k(q[i].l,q[i].r,q[i].k) << '\n';
else if (q[i].type == 3) //操作3,注意下标是q[i].l,然后k要进行离散化
add(q[i].l,-1),a[q[i].l] = get(q[i].k),add(q[i].l,1);
else if (q[i].type == 4) // 操作4,注意k要进行离散化
cout << get_pre(q[i].l,q[i].r,get(q[i].k)) << '\n';
else if (q[i].type == 5) //操作5,注意k要进行离散化
cout << get_suf(q[i].l,q[i].r,get(q[i].k)) << '\n';
}
return 0;
}
没有pushup版 O(mlog2n)
#include <bits/stdc++.h>
#define ls(q) tr[q].ls
#define rs(q) tr[q].rs
#define pc(c) putchar(c)
#define rep(a,b,c) for (int (a) = (b) ; (a) < (c) ; ++(a))
using namespace std;
using ll = long long ;
using pii = pair<int,int> ;
const int maxn = 5e4 + 10,inf = 0x7fffffff ;
struct pb{
int type,l,r,k;
}q[maxn];
int a[maxn],root[maxn],idx,len,n,m;
struct node {
int ls,rs,sz = 0;
}tr[maxn << 8]; // need nlognlogn space
int vec[maxn << 2],lt[maxn],llen,rt[maxn],rlen;
inline int get(int x){
return lower_bound(vec + 1,vec + len + 1,x) - vec;
}
inline int lowbit(int x){
return x & -x ;
}
inline void refresh(int l,int r){
llen = rlen = 0;
while (l) lt[++ llen] = root[l],l -= lowbit(l);
while (r) rt[++ rlen] = root[r],r -= lowbit(r);
}
inline void modify(int q,int l,int r,int pos,int val){
while (l ^ r){
tr[q].sz += val ;
int mid = l + r >> 1 ;
if ( pos <= mid ) {
if (!ls(q)) ls(q) = ++ idx ;
q = ls(q),r = mid ;
}else {
if (!rs(q)) rs(q) = ++ idx ;
q = rs(q),l = mid + 1;
}
}
tr[q].sz += val ;
}
inline void add(int pos,int val){
for (int i = pos ; i <= n ; i += lowbit(i)) modify(root[i],1,len,a[pos],val);
}
inline int get_rk(int l,int r,int k,int res = 1,int mid = 0){
refresh(l - 1,r);
l = 1,r = len ;
while (l ^ r){
mid = l + r >> 1;
if (k <= mid) {
rep(i,1,rlen + 1) rt[i] = ls(rt[i]);
rep(i,1,llen + 1) lt[i] = ls(lt[i]);
r = l + r >> 1;
continue ;
}
rep(i,1,rlen + 1) res += tr[ls(rt[i])].sz,rt[i] = rs(rt[i]);
rep(i,1,llen + 1) res -= tr[ls(lt[i])].sz,lt[i] = rs(lt[i]);
l = (l + r >> 1) + 1;
}
return res;
}
inline int get_k(int l,int r,int k,int sum = 0){
refresh(l - 1,r);
l = 1,r = len;
while (l ^ r) {
sum = 0 ;
rep(i,1,rlen + 1) sum += tr[ls(rt[i])].sz ;
rep(i,1,llen + 1) sum -= tr[ls(lt[i])].sz ;
if (sum >= k) {
rep(i,1,rlen + 1) rt[i] = ls(rt[i]);
rep(i,1,llen + 1) lt[i] = ls(lt[i]);
r = l + r >> 1;
continue ;
}
rep(i,1,rlen + 1) rt[i] = rs(rt[i]);
rep(i,1,llen + 1) lt[i] = rs(lt[i]);
k -= sum ;
l = (l + r >> 1) + 1;
}
return vec[l];
}
inline int get_pre(int l,int r,int x){
int rk = get_rk(l,r,x) - 1 ;
return rk ? get_k(l,r,rk) : -inf ;
}
inline int get_suf(int l,int r,int x){
int rk = get_rk(l,r,x + 1) ;
return rk == r - l + 2 ? inf : get_k(l,r,rk) ;
}
int main(){
ios::sync_with_stdio(false);
cin >> n >> m ;
vec[len] = -inf ;
rep(i,1,n + 1) cin >> a[i],vec[++len] = a[i],root[i] = ++idx;
rep(i,1,m + 1) {
cin >> q[i].type ;
if (q[i].type == 3)
cin >> q[i].l >> q[i].k;
else
cin >> q[i].l >> q[i].r >> q[i].k ;
if (q[i].type ^ 2)
vec[++len] = q[i].k;
}
vec[++ len] = inf ;
sort(vec,vec + len);
len = unique(vec,vec + len + 1) - vec - 1;
rep(i,1,n + 1) a[i] = get(a[i]),add(i,1);
rep(i,1,m + 1){
if (q[i].type == 1)
cout << get_rk(q[i].l,q[i].r,get(q[i].k)) << '\n';
else if (q[i].type == 2)
cout << get_k(q[i].l,q[i].r,q[i].k) << '\n';
else if (q[i].type == 3)
add(q[i].l,-1),a[q[i].l] = get(q[i].k),add(q[i].l,1);
else if (q[i].type == 4)
cout << get_pre(q[i].l,q[i].r,get(q[i].k)) << '\n';
else if (q[i].type == 5)
cout << get_suf(q[i].l,q[i].r,get(q[i].k)) << '\n';
}
return 0;
}