算法1
(前缀和、暴力枚举) $O(n^3)$
记 $ps_i = a_0 \oplus a_1 \oplus \cdots \oplus a_{i - 1}$
根据异或的性质我们可以得到
$$ \begin{aligned} x &= a_i \oplus a_{i + 1} \oplus \cdots \oplus a_{j - 1} \\\ &= (a_0 \oplus a_1 \oplus \cdots \oplus a_{i - 1} ) \oplus (a_0 \oplus a_1 \oplus \cdots \oplus a_{i - 1} )\oplus (a_i \oplus a_{i + 1} \oplus \cdots \oplus a_{j - 1}) \\\ &= ps_i \oplus ps_j \end{aligned} $$
同理,$y = ps_j \oplus ps_{k + 1}$
做法:先预处理出异或前缀和,然后暴力枚举三元组 $(i, j, k)$,计算出相应的 $x$ 和 $y$,并判断 $x$ 和 $y$ 是否相等。
C++ 代码
class Solution {
public:
int countTriplets(vector<int>& a) {
int n = a.size();
vector<int> ps(n + 1);
for (int i = 0; i < n; ++i)
ps[i + 1] = ps[i] ^ a[i];
int ans = 0;
for (int i = 0; i < n; ++i) {
for (int j = i + 1; j < n; ++j) {
for (int k = j; k < n; ++k) {
int x = ps[i] ^ ps[j];
int y = ps[j] ^ ps[k + 1];
if (x == y) ++ans;
}
}
}
return ans;
}
};
Python 代码
class Solution:
def countTriplets(self, arr: List[int]) -> int:
n = len(arr)
ps = [0] * (n + 1)
for i in range(n):
ps[i + 1] = ps[i] ^ arr[i]
res = 0
for i in range(n - 1):
for j in range(i + 1, n):
for k in range(j, n):
if ps[i] ^ ps[j] == ps[j] ^ ps[k + 1]:
res += 1
return res
算法2
(前缀和) $O(n^2)$
注意到若 $x = y$,则必然有 $x \oplus y = 0$,即
$$
a_i \oplus a_{i + 1} \oplus \cdots \oplus a_{j - 1} \oplus a_j \oplus a_{j + 1} \oplus \cdots \oplus a_k = 0.
$$
所以 $ps_i \oplus (x \oplus y) = ps_i$。又 $ps_{k + 1} = ps_i \oplus (x \oplus y)$, 于是,我们只需要找到二元组 $(i,k)$ 使得 $ps_i = ps_{k + 1}$ 即可。而 $j$ 可以在区间 $[i + 1, k]$ 任选一个位置插入,有 $k - i$ 种选法。
C++ 代码
class Solution {
public:
int countTriplets(vector<int>& a) {
int n = a.size();
vector<int> ps(n + 1);
for (int i = 0; i < n; ++i)
ps[i + 1] = ps[i] ^ a[i];
int ans = 0;
for (int i = 0; i < n; ++i) {
for (int k = i + 1; k < n; ++k) {
if (ps[i] == ps[k + 1])
ans += k - i;
}
}
return ans;
}
};
算法3
(哈希表、前缀和) $O(n)$
继续优化:
对于下标 $k$ 来说,如果下标 $i = i_1, i_2, \cdots, i_n$ 时都满足 $ps_i = ps_{k + 1}$,那么根据上面的算法2,这些二元组 $(i_1, k), (i_2, k), \cdots (i_n, k)$ 对答案的贡献之和就是
$$ (k - i_1) + (k - i_2) + \cdots + (k - i_n) = nk - (i_1 + i_2 + \cdots + i_n) $$
做法:可以开两个哈希表cnt
和 sum
,在遍历下标 $k$ 的同时,用 $cnt$ 记录 $ps_k$ 的出现次数,用 $sum$ 记录 值为 $ps_k$ 的下标之和。
C++ 代码
class Solution {
public:
int countTriplets(vector<int>& a) {
int n = a.size();
unordered_map<int, int> cnt, sum;
int ans = 0, ps = 0;
for (int k = 0; k < n; ++k) {
int x = a[k];
if (cnt.count(ps ^ x))
ans += cnt[ps ^ x] * k - sum[ps ^ x];
++cnt[ps];
sum[ps] += k;
ps ^= x;
}
return ans;
}
};