题目描述
难度:[紫]省选/NOI-
输入n(1≤n≤5×105),k(1≤k≤2×105)和长为n的数组a(0≤a[i]<232)。
求a的非空连续子数组异或和中的前k大之和。
保证a有k个非空连续子数组。
输入样例
3 2
1 2 3
输出样例
6
算法
Trie
+多路归并
先求一下数组a的前缀异或和s,对于任意子数组[l,r],它的元素异或和应该是s[r]⊕s[l−1]。所以问题就转化成了“s[i]⊕s[j](i,j∈[0,n])最大的k个异或值之和”。
考虑到此时(i,j)和(j,i)是完全等价的,因此我们找最大的2k个元素之和,最后再除以2会更方便一些。接下来对每个s[i]((i∈[0,n]))进行多路归并,将与s[i]异或的最大值xormax存入一个大根堆中。求与某个s[i]第rk大的异或值可以先将所有s[i]存入一个01
字典树中,然后按位贪心,有模板可以抄。
选择最大的top2k异或对,先取出堆顶t,它就是最大的异或对,将它的异或值累加到答案上。此时我们需要知道它是哪个i对应的s[i],以及这个异或值是s[i]与其他元素异或的第几大异或值rk。因此这个大根堆需要三个信息:i、rk,以及对应的异或值,并按照异或值排序。弹出堆顶之后(此时弹出的是(i,1,xormax)),s[i]就需要继续滚动,将与s[i]异或第二大的异或值加入到堆中,便于下次求全局第二大的异或值。周而复始,直到将top2k大的异或值累加到答案上,最后除以2打印出来即可。
复杂度分析
时间复杂度
最大数的数位大概就是十几位,因此每次往01
字典树中插入元素可以看成是个比较大的常数操作,构建01
字典树的时间复杂度是O(n)的。求s[i]与其他数异或的底rk大的异或值也是按位贪心地,时间复杂度也可以看成个比较大的常数(实际上是O(log10A),其中A是最大的异或值),因此遍历[0,n]初始化大根堆的时间复杂度为O(nlog2n)。
求top2k大的异或和需要循环O(k)次,每次循环可能会存在往大根堆中插入数据的操作,时间复杂度为O(log2n),时间复杂度为O(klog2n)。
综上,算法整体的时间复杂度为O((n+k)log2n)。
空间复杂度
主要的空间瓶颈在于01
字典树,需要开O(nlog10A)的空间,可以粗略看成是O(n)的。其余的前缀异或和数组s,子树大小数组sz都是O(n)级别的。因此,额外空间复杂度为O(n)。
C++ 代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
typedef long long LL;
const int N = 20000010;
struct Node {
int id, rk;
LL w;
bool operator <(const Node &a)const {
return w < a.w;
}
};
LL s[N];
int a[N][2], sz[N], n, k, tot;
void insert(LL x) {
int u = 0;
for(int i = 31; ~i; i--) {
int bit = (x>>i)&1;
sz[u]++; // 子树大小
if(!a[u][bit]) {
a[u][bit] = ++tot;
}
u = a[u][bit];
}
sz[u]++;
}
LL query(LL x, int rk) {
int u = 0;
LL ans = 0;
for(int i = 31; ~i; i--) {
int bit = (x>>i)&1;
if(!a[u][bit^1]) {
u = a[u][bit];
}else if(rk <= sz[a[u][bit^1]]) {
u = a[u][bit^1];
ans |= 1LL<<i;
}else {
rk -= sz[a[u][bit^1]];
u = a[u][bit];
}
}
return ans;
}
int main() {
scanf("%d%d", &n, &k);
priority_queue<Node> heap;
insert(s[0]);
for(int i = 1; i <= n; i++) {
scanf("%d", &s[i]);
s[i] ^= s[i - 1];
insert(s[i]);
}
for(int i = 0; i <= n; i++) {
heap.push((Node){i, 1, query(s[i], 1)});
}
LL ans = 0;
for(int i = 1; i <= 2*k; i++) {
Node t = heap.top();
ans += t.w;
heap.pop();
if(t.rk < n) {
heap.push((Node){t.id, t.rk + 1, query(s[t.id], t.rk + 1)});
}
}
printf("%lld\n", ans>>1);
return 0;
}