题目描述
给定一个非负整数数列 a,初始长度为 N。
请在所有长度不超过 M 的连续子数组中,找出子数组异或和的最大值。
子数组的异或和即为子数组中所有元素按位异或得到的结果。
注意:子数组可以为空。
输入格式
第一行包含两个整数 N,M。
第二行包含 N 个整数,其中第 i 个为 ai。
输出格式
输出可以得到的子数组异或和的最大值。
数据范围
对于 20% 的数据,1≤M≤N≤100
对于 50% 的数据,1≤M≤N≤1000
对于 100% 的数据,1≤M≤N≤105,0≤ai≤231−1
样例
输入样例:
3 2
1 2 4
输出样例:
6
算法1
(trie树) $O()$
blablabla
时间复杂度
参考文献
python3 代码
[n, m] = [int(x) for x in input().split()]
nums = [int(x) for x in input().split()]
class Trie:
def __init__(self):
self.child = [None for _ in range(2)]
self.cnt = 0
def insert(self, x: int) -> None:
root = self
for i in range(30, -1, -1):
ID = (x >> i) & 1
if root.child[ID] == None:
root.child[ID] = Trie()
root = root.child[ID]
root.cnt += 1
def delete(self, x: int) -> None:
root = self
for i in range(30, -1, -1):
ID = (x >> i) & 1
root = root.child[ID]
root.cnt -= 1
def query(self, x: int) -> int:
root = self
y = 0
for i in range(30, -1, -1):
ID = (x >> i) & 1
if root.child[(1-ID)] == None or root.child[(1-ID)].cnt <= 0:
y = y * 2 + ID
root = root.child[ID]
else:
y = y *2 + (1-ID)
root = root.child[(1-ID)]
return y ^ x
presum = [0 for _ in range(n + 1)]
for i in range(n):
presum[i+1] = presum[i] ^ nums[i]
res = 0
T = Trie()
for i in range(m):
T.insert(presum[i])
res = max(res, T.query(presum[i+1]))
for i in range(m, n):
T.insert(presum[i])
T.delete(presum[i-m])
res = max(res, T.query(presum[i+1]))
print(res)
C++ 代码
#include<bits/stdc++.h>
using namespace std;
class Trie
{
public:
Trie * child[2];
int cnt;
Trie()
{
memset(child, 0, sizeof(child));
this->cnt = 0;
}
void insert(int x)
{
Trie * root = this;
for (int i = 30; i > -1; i --)
{
int ID = (x >> i) & 1;
if (root->child[ID] == NULL)
root->child[ID] = new Trie();
root = root->child[ID];
root->cnt ++;
}
}
void dele(int x)
{
Trie * root = this;
for (int i = 30; i > -1; i --)
{
int ID = (x >> i) & 1;
root = root->child[ID];
root->cnt --;
}
}
int query(int x)
{
Trie * root = this;
int y = 0;
for (int i = 30; i > -1; i --)
{
int ID = (x >> i) & 1;
if (root->child[(1-ID)] == NULL || root->child[(1-ID)]->cnt <= 0)
{
root = root->child[ID];
y = y * 2 + ID;
}
else
{
root = root->child[(1-ID)];
y = y * 2 + (1-ID);
}
}
return y ^ x;
}
};
int main()
{
int n, m;
cin >> n;
cin >> m;
int nums[n];
memset(nums, 0, sizeof(nums));
for (int i = 0; i < n; i ++)
cin >> nums[i];
vector<int> presum(n + 1, 0);
for (int i = 0; i < n; i ++)
presum[i + 1] = presum[i] ^ nums[i];
Trie * T = new Trie();
int res = 0;
for (int i = 0; i < m; i ++)
{
T->insert(presum[i]);
res = max(res, T->query(presum[i+1]));
}
for (int i = m; i < n; i ++)
{
T->insert(presum[i]);
T->dele(presum[i - m]);
res = max(res, T->query(presum[i + 1]));
}
cout << res << endl;
return 0;
}