考虑前缀和 $s[i]$,原问题等价于对于一个点 $p$,找到 $s[1\cdots p-1]$ 的一个点 $p’$
使得 $s[p’] \oplus s[p]$ 最大,可以考虑使用可持久化 $\text{trie}$
对于可持久化 $\text{trie}$,$\text{root}(p-1)$ 这个版本中只有区间 $[1\cdots p-1]$ 的信息
利用主席树的思想,可以解决可持久化 $\text{trie}$ 的 $\text{k-query}$ 问题
对于区间 $[1\cdots r]$,要找到一个 $l \in [1\cdots r-1]$,使得 $s_l \oplus s_r$ 为第 $k$ 大
令 $p \leftarrow \text{root}(r-1), \ val \leftarrow s[r]$,从高位到低位检查 $val$ 的第 $b$ 位 $c$
如果 $\textbf{size}(t(p, c \oplus 1)) \geqslant k$,那么 $res += (1 \ll b), \ p \leftarrow t(p, c\oplus 1)$
否则的话,$k’ \leftarrow k - \textbf{size}(t(p, c\oplus 1))$,$p \leftarrow t(p, c)$,递归在 $t(p, c)$ 子树查找第 $k’$ 大
需要注意的是边界,想要让 $r = 1$ 时有意义,必须提前在 $\text{trie}$ 树中插入 $\text{insert}(\text{root}(0), 0)$
表示在 $\text{root}(0)$ 初始化插入一个每个位都是 $0$ 的数
具体来说
-
对于 $H$ 位的数 $val$,由于要统计 $\text{size}$ 信息,所以递归地插入
$\textbf{insert}(pre, p, H, val)$,递归的边界是 $H < 0, \text{size}(p) = \text{size}(pre) + 1$
($H = 0$ 时插入最后一个字符 $c$,递归执行 $t(p, c)$ 之后,边界 $H = -1$) -
$res \leftarrow (i, rk(i))$,表示在区间 $[1\cdots i-1]$ 中找到一个 $j$
使得 $res = s_j \oplus x_i$ 为第 $rk(i)$ 大,很显然一开始 $rk(i) = 1$ -
建立一个优先队列 $\text{que}$,对于 $\forall \ r \in [1, n]$
将 $\textbf{ask}(\text{root}(r-1), rk(r), s[r])$ 的结果放入 $\text{que}$ 中 -
取出堆顶元素,此时堆中最大元素假设为 $(res, p)$
表示此时 $\exists l \in [1, p)$,使得 $s_l \oplus s_p$ 为第 $1$ 大,其值为 $res$
将其累加到答案中,删掉 $s_l$,注意要接着找到 $[1, l-1] \cup [l+1, p)$ 中第 $1$ 大,将其放入堆中
注意到 $[1, l-1] \cup [l+1, p)$ 中的第 $1$ 大,等价于 $[1, p)$ 中的第 $2$ 大
由此在编程实现上可以更简单一些,一开始令 $rk(i) = 1$,取出堆顶元素 $(res, p)$ 之后
执行查询 $res’ \leftarrow \textbf{ask}(\text{root}(p), ++rk(p), s[p])$,再继续将 $res’$ 放入堆中
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
#define Cmp(a, b) memcmp(a, b, sizeof(b))
#define Cpy(a, b) memcpy(a, b, sizeof(b))
#define Set(a, v) memset(a, v, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define _forS(i, l, r) for(set<int>::iterator i = (l); i != (r); i++)
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define _forPlus(i, l, d, r) for(int i = (l); i + d < (r); i++)
#define lowbit(i) (i & (-i))
#define MPR(a, b) make_pair(a, b)
pair<int, int> crack(int n) {
int st = sqrt(n);
int fac = n / st;
while (n % st) {
st += 1;
fac = n / st;
}
return make_pair(st, fac);
}
inline ll qpow(ll a, int n) {
ll ans = 1;
for(; n; n >>= 1) {
if(n & 1) ans *= 1ll * a;
a *= a;
}
return ans;
}
template <class T>
inline bool chmax(T& a, T b) {
if(a < b) {
a = b;
return true;
}
return false;
}
ll gcd(ll a, ll b) {
return b == 0 ? a : gcd(b, a % b);
}
ll ksc(ll a, ll b, ll mod) {
ll ans = 0;
for(; b; b >>= 1) {
if (b & 1) ans = (ans + a) % mod;
a = (a * 2) % mod;
}
return ans;
}
ll ksm(ll a, ll b, ll mod) {
ll ans = 1 % mod;
a %= mod;
for(; b; b >>= 1) {
if (b & 1) ans = ksc(ans, a, mod);
a = ksc(a, a, mod);
}
return ans;
}
template <class T>
inline bool chmin(T& a, T b) {
if(a > b) {
a = b;
return true;
}
return false;
}
template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
int n = a.size(), m = b.size();
int i;
for(i = 0; i < n && i < m; i++) {
if (a[i] < b[i]) return true;
else if (b[i] < a[i]) return false;
}
return (i == n && i < m);
}
// ============================================================== //
typedef pair<ll, int> PII;
const int maxn = 500000 + 10, N = maxn * 35;
const int H = 33;
int n, k, rk[maxn], root[maxn];
ll s[maxn];
priority_queue<PII> heap;
// insert(pre, p, H, val)
// ask(root(p), rk, &ans)
class Trie {
public:
int tot;
int t[N][2], sz[N];
Trie() {
tot = 0;
memset(t, 0, sizeof t);
memset(sz, 0, sizeof sz);
}
void insert(int pre, int p, int H, ll val) {
if (H < 0) {
sz[p] = sz[pre] + 1;
return;
}
int c = val >> H & 1;
if (pre) t[p][c^1] = t[pre][c^1];
t[p][c] = ++tot;
insert(t[pre][c], t[p][c], H-1, val);
sz[p] = sz[t[p][c]] + sz[t[p][c^1]];
}
void ask(int p, int rk, int H, ll val, ll &res) {
if (H < 0) return;
int c = val >> H & 1;
if (sz[ t[p][c^1] ] >= rk) {
res = (res << 1 | 1);
ask(t[p][c^1], rk, H-1, val, res);
}
else {
res <<= 1;
ask(t[p][c], rk - sz[t[p][c^1]], H-1, val, res);
}
}
} trie;
void solve() {
for (int i = 1; i <= n; i++) {
ll res = 0;
trie.ask(root[i-1], rk[i], H, s[i], res);
heap.push({res, i});
}
ll ans = 0;
while (k--) {
auto x = heap.top(); heap.pop();
ans += x.first;
int r = x.second;
ll res = 0;
trie.ask(root[r-1], ++rk[r], H, s[r], res);
heap.push({res, r});
}
printf("%lld\n", ans);
}
int main() {
freopen("input.txt", "r", stdin);
// init
memset(root, 0, sizeof root);
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++) {
ll x;
scanf("%lld", &x);
s[i] = s[i-1] ^ x;
rk[i] = 1;
}
// per trie
root[0] = ++trie.tot;
trie.insert(0, root[0], H, 0);
for (int i = 1; i <= n; i++) {
root[i] = ++trie.tot;
trie.insert(root[i-1], root[i], H, s[i]);
}
// solve
solve();
}