题目描述
给定三个长度均为 N 的整数数组:
A=[A1,A2,…,AN]
B=[B1,B2,…,BN]
C=[C1,C2,…,CN]
需要统计满足以下两个条件的三元组 (i,j,k) 的数量:
1. 1≤i,j,k≤N
2. Ai<Bj<Ck
样例
输入:
3
1 1 1
2 2 2
3 3 3
输出:
27
样例解释
A=[1,1,1], B=[2,2,2], C=[3,3,3]。
对于任意选择的 i,j,k(每个都有 3 种选择),都满足 Ai=1<Bj=2<Ck=3。
因此,总共有 3×3×3=27 个满足条件的三元组。
算法1 (排序 + 二分查找)
O(NlogN)
思路分析
-
问题核心: 统计满足严格递增关系 Ai<Bj<Ck 的三元组 (i,j,k) 的数量。
-
暴力枚举: 最直接的方法是三重循环枚举所有可能的 i,j,k 组合,然后检查是否满足 Ai<Bj<Ck。时间复杂度为 O(N3),对于 N=105 来说太慢了。
-
优化思路:固定中间元素: 我们可以尝试优化。考虑固定中间数组 B 的下标 j。对于一个固定的 Bj,我们需要找到:
- 有多少个下标 i 使得 Ai<Bj。
- 有多少个下标 k 使得 Ck>Bj。
假设满足 Ai<Bj 的 i 有count_A
个,满足 Ck>Bj 的 k 有count_C
个。那么,对于这个固定的 Bj,它能构成的满足条件的三元组数量就是count_A * count_C
。
最终的答案就是对所有 j (1≤j≤N) 计算count_A * count_C
并求和:
Total Count=N∑j=1(count of i where Ai<Bj)×(count of k where Ck>Bj)
-
高效计算 count_A 和 count_C: 如何快速计算对于给定的 Bj,有多少个 Ai<Bj 和 Ck>Bj?
- 如果我们将数组 A 和数组 C 排序,就可以使用二分查找来高效地完成计数。
- 计算
count_A(B_j)
: 在排序后的数组 A 中,找到第一个大于等于 Bj 的元素的位置。该位置之前的所有元素都满足 Ai<Bj。可以使用std::lower_bound
来找到这个位置。lower_bound(A_{start}, A_{end}, B_j)
返回指向第一个不小于 Bj 的元素的迭代器。从数组开始到这个迭代器之间的元素数量,就是满足 Ai<Bj 的 i 的数量。 - 计算
count_C(B_j)
: 在排序后的数组 C 中,找到第一个严格大于 Bj 的元素的位置。可以使用std::upper_bound
来找到这个位置。upper_bound(C_{start}, C_{end}, B_j)
返回指向第一个大于 Bj 的元素的迭代器。从这个迭代器开始到数组末尾的所有元素,都满足 Ck>Bj。元素的数量可以通过数组总长度 - upper_bound返回的位置索引
来计算。
-
算法步骤:
a. 读入 N 和三个数组 A,B,C。
b. 对数组 A 进行排序。
c. 对数组 C 进行排序。
d. 初始化总计数ans = 0
(使用long long
防止溢出)。
e. 遍历数组 B 的每个元素 Bj (下标 j 从 1 到 N):
i. 在排序后的 A 中,使用lower_bound
找到第一个 Ai≥Bj 的位置。计算出有多少个 Ai<Bj,记为count_A
。
ii. 在排序后的 C 中,使用upper_bound
找到第一个 Ck>Bj 的位置。计算出有多少个 Ck>Bj,记为count_C
。
iii. 将count_A * count_C
加到ans
上。
f. 输出ans
。 -
代码解读 (第一份代码):
num[0]
,num[1]
,num[2]
分别存储输入的数组 A, B, C (使用 1-based 索引存储,从下标 1 到 n)。- 对
num[0]
,num[1]
,num[2]
进行排序。注意:这里对 B 也进行了排序,这意味着接下来的循环变量i
对应的是排序后 B 数组的下标,而不是原始 B 数组的下标。但最终结果是一样的,因为我们是对 B 中的值进行计数,无论 B 是否排序,每个 B 值都会被考虑到。 for (int i = 1; i <= n; i ++ )
: 遍历排序后的 B 数组。key = num[1][i]
是当前考虑的 B 值。auto it1 = lower_bound(num[0] + 1, num[0] + n + 1, key) - num[0] - 1;
: 在排序后的 A (范围 1 到 n) 中查找。lower_bound(...) - (num[0] + 1)
得到的是从 1 开始计数的、值小于key
的元素个数。代码中的- num[0] - 1
最终也正确地计算了这个数量count_A
。auto it2 = upper_bound(num[2] + 1, num[2] + n + 1, key) - num[2];
: 在排序后的 C (范围 1 到 n) 中查找。it2
是第一个大于key
的元素在 1-based 数组中的下标。n - it2 + 1
: 计算大于key
的元素数量count_C
。 (总数 n - 小于等于key的数量(it2-1) = n - it2 + 1)。if (it1 >= 1 and it2 <= n)
: 这个条件判断似乎有些冗余或微妙,但核心逻辑是获取count_A
(it1
) 和count_C
(n-it2+1
)。it1
范围是[0, n]
,it2
范围是[1, n+1]
。ans += 1LL * it1 * (n - it2 + 1);
: 累加结果。1LL
保证乘法使用long long
。- 最终输出
ans
。
时间复杂度
- 读入数据: O(N)。
- 排序数组 A 和 C: O(N log N)。
- 遍历数组 B (N 次),每次执行两次二分查找 (O(log N)):O(N log N)。
- 总时间复杂度: O(N log N)。
参考文献
std::sort
std::lower_bound
std::upper_bound
C++ 代码 (第一份)
#include <bits/stdc++.h> // 引入所有标准库
using namespace std;
using i64 = long long; // 定义 i64 为 long long 的别名
constexpr int N = 1e5 + 10; // 定义常量 N,略大于数据范围上限
int num[3][N]; // 二维数组存储 A, B, C。num[0] 存 A, num[1] 存 B, num[2] 存 C
int main() {
// 关闭 C++ 标准 IO 流与 C stdio 的同步,提高 cin/cout 速度
ios::sync_with_stdio(false);
// 解除 cin 和 cout 的绑定,进一步提速
cin.tie(nullptr);
int n; // 输入数组长度 N
cin >> n;
// 读入三个数组,存储在 num[0], num[1], num[2] 的下标 1 到 n 处
for (int i = 0; i < 3; i ++ ) {
for (int j = 1; j <= n; j ++ ) {
cin >> num[i][j];
}
}
// 对 A, B, C 三个数组分别进行排序 (范围是下标 1 到 n)
for (int i = 0; i < 3; i ++ ) {
sort(num[i] + 1, num[i] + n + 1);
}
i64 ans = 0; // 初始化答案为 0,使用 long long 类型
// 遍历排序后的 B 数组中的每个元素
for (int i = 1; i <= n; i ++ ) {
int key = num[1][i]; // 当前 B 数组的元素值 B_j
// 在排序后的 A 数组 (num[0]) 中,查找第一个 >= key 的元素
// it1 计算的是严格小于 key 的元素的数量 (count_A)
auto it1 = lower_bound(num[0] + 1, num[0] + n + 1, key) - num[0] - 1;
// 在排序后的 C 数组 (num[2]) 中,查找第一个 > key 的元素
// it2 是该元素在 1-based 数组中的下标
auto it2 = upper_bound(num[2] + 1, num[2] + n + 1, key) - num[2];
// 计算严格大于 key 的元素的数量 (count_C = n - (it2 - 1))
i64 count_C = n - it2 + 1;
// (这里的 if 条件可能不是必需的,因为 it1 和 count_C 自然非负)
// if (it1 >= 1 and it2 <= n) { // 原代码的条件
if (it1 >= 0 && count_C >= 0) { // 更准确的条件是计数非负
// 累加贡献: count_A * count_C
// 使用 1LL 确保乘法以 long long 进行
ans += 1LL * it1 * count_C;
}
}
// 输出最终答案
cout << ans << "\n";
return 0; // 程序正常结束
}
算法2 (频次数组 + 前缀和)
O(N+V),其中 V 是数组元素的最大值
思路分析
-
核心思想: 同样是固定中间元素 Bj,计算有多少 Ai<Bj 和多少 Ck>Bj。但这次我们不使用排序和二分查找,而是利用频次数组和前缀和。
-
频次数组: 由于数组元素的值域范围不大 (0≤Ai,Bi,Ci≤105),我们可以使用一个频次数组
cnt
来统计每个值出现的次数。 -
前缀和: 对频次数组计算前缀和。令
s[x]
表示数值小于等于 x 的元素个数。即 s[x]=∑xv=0cnt[v]。 -
计算
count_A(B_j)
:- 首先,统计数组 A 中每个值出现的次数,存入
cnt_A
数组。 - 计算
cnt_A
的前缀和数组s_A
。 - 对于给定的 Bj,满足 Ai<Bj 的元素数量,就是值小于等于 Bj−1 的元素的数量,即
s_A[B_j - 1]
。
- 首先,统计数组 A 中每个值出现的次数,存入
-
计算
count_C(B_j)
:- 首先,统计数组 C 中每个值出现的次数,存入
cnt_C
数组。 - 计算
cnt_C
的前缀和数组s_C
。 - 对于给定的 Bj,满足 Ck>Bj 的元素数量,等于 C 中的总元素数量 N 减去值小于等于 Bj 的元素数量。即 N−sC[Bj]。
- 首先,统计数组 C 中每个值出现的次数,存入
-
算法步骤:
a. 读入 N 和数组 A,B,C。
b. (代码中) 将所有 Ai,Bi,Ci 的值加 1。这样做是为了方便处理,使得数值范围从 1 开始,避免下标 0 的歧义。现在的数值范围是 1≤A′i,B′i,C′i≤105+1。
c. 计算as[j] = count_A(B_j')
:
i. 创建频次数组cnt
和前缀和数组s
(大小为 V+1,其中 V=105+1)。
ii. 统计 A′ 的频次到cnt
。
iii. 计算cnt
的前缀和到s
。
iv. 对于每个 B′j,as[j] = s[B_j' - 1]$。 d. 计算
cs[j] = count_C(B_j’): i. 重置
cnt和
s数组为 0。 ii. 统计 $C'$ 的频次到
cnt。 iii. 计算
cnt的前缀和到
s。 iv. 对于每个 $B_j'$,
cs[j] = s[V] - s[B_j’](其中
s[V]是 C' 的总数 N)。 e. 初始化
res = 0(long long)。 f. 遍历 $j$ 从 0 到 $N-1$: i.
res += (long long)as[j] * cs[j]。 g. 输出
res`。 -
代码解读 (第二份代码):
a[N], b[N], c[N]
存储输入的数组 (0-based 索引)。as[N], cs[N]
分别存储对于每个b[i]
,小于它的 A 元素个数和大于它的 C 元素个数。cnt[N], s[N]
是辅助数组,N
在这里代表值域上限 105+10。scanf
用于读入。a[i] ++, b[i] ++, c[i] ++
: 对所有输入值加 1,值域变为 [1, 100001]。- 计算
as
:cnt[a[i]] ++
: 统计 A’ 的频次。s[i] = s[i - 1] + cnt[i]
: 计算频次前缀和。s[x]
= count of A’ values <= x.as[i] = s[b[i] - 1]
: 计算 count of A’ values < b[i].
- 计算
cs
:memset(cnt, 0, sizeof cnt); memset(s, 0, sizeof s);
: 重置辅助数组。cnt[c[i]] ++
: 统计 C’ 的频次。s[i] = s[i - 1] + cnt[i]
: 计算频次前缀和。s[x]
= count of C’ values <= x.cs[i] = s[N - 1] - s[b[i]];
: 计算 count of C’ values > b[i].N - 1
在这里是值域上限 (100001 - 1 = 100000 左右,s[N-1]
应该是s[MaxValue]
,代表总数 N)。
- 计算
res
: 循环累加(LL)as[i] * cs[i]
。 cout << res << endl;
: 输出结果。
时间复杂度
- 读入数据: O(N)。
- 值加 1: O(N)。
- 计算
as
: 统计频次 O(N),计算前缀和 O(V),计算as
数组 O(N)。总共 O(N + V)。 - 计算
cs
: 统计频次 O(N),计算前缀和 O(V),计算cs
数组 O(N)。总共 O(N + V)。 - 计算最终结果: O(N)。
- 总时间复杂度: O(N + V),其中 V 是数组元素的最大值 (105)。由于 V 和 N 在同一数量级,可以看作是 O(N)。
参考文献
- Frequency Array (Bucket Sort idea)
- Prefix Sums
C++ 代码 (第二份)
#include <iostream> // 使用 iostream 进行输入输出
#include <cstring> // 使用 cstring 中的 memset
#include <algorithm> // (可能未使用,但包含是好习惯)
using namespace std;
typedef long long LL; // 定义 LL 为 long long 的别名
const int N = 1e5 + 10; // 定义常量 N,表示数组大小和值域上限
int n; // 数组长度
int a[N], b[N], c[N]; // 输入数组 A, B, C
int as[N]; // as[i] 存储小于 b[i] 的 A 元素的数量
int cs[N]; // cs[i] 存储大于 b[i] 的 C 元素的数量
int cnt[N]; // 临时频次计数数组
int s[N]; // 临时前缀和数组
int main()
{
scanf("%d", &n); // 读入 N
// 读入 A, B, C 并将每个元素值加 1 (值域变为 [1, 100001])
for (int i = 0; i < n; i ++ ) scanf("%d", &a[i]), a[i] ++ ;
for (int i = 0; i < n; i ++ ) scanf("%d", &b[i]), b[i] ++ ;
for (int i = 0; i < n; i ++ ) scanf("%d", &c[i]), c[i] ++ ;
// --- 计算 as 数组 ---
// 统计数组 A' (加 1 后) 中每个值的频次
for (int i = 0; i < n; i ++ ) cnt[a[i]] ++ ;
// 计算频次的前缀和,s[x] 表示 A' 中值 <= x 的元素个数
for (int i = 1; i < N; i ++ ) s[i] = s[i - 1] + cnt[i];
// 对于每个 b[i]',小于它的 A' 元素的个数等于值 <= b[i]'-1 的元素个数
for (int i = 0; i < n; i ++ ) as[i] = s[b[i] - 1];
// --- 计算 cs 数组 ---
// 重置频次数组和前缀和数组
memset(cnt, 0, sizeof cnt);
memset(s, 0, sizeof s);
// 统计数组 C' (加 1 后) 中每个值的频次
for (int i = 0; i < n; i ++ ) cnt[c[i]] ++ ;
// 计算频次的前缀和,s[x] 表示 C' 中值 <= x 的元素个数
for (int i = 1; i < N; i ++ ) s[i] = s[i - 1] + cnt[i];
// 对于每个 b[i]',大于它的 C' 元素的个数 = 总数 N - (值 <= b[i]' 的元素个数)
// s[N-1] 在这里近似表示 C' 的总数 N (假设最大值不超过 N-1)
// 实际上应该是 s[MaxValue] 即 s[100001] 左右的值,它等于 N
for (int i = 0; i < n; i ++ ) cs[i] = s[N - 1] - s[b[i]]; // s[N-1] == n
// --- 计算最终结果 ---
LL res = 0; // 初始化结果为 0,使用 long long
// 遍历 B 数组,累加贡献
for (int i = 0; i < n; i ++ ) res += (LL)as[i] * cs[i]; // 强制类型转换确保乘法用 long long
cout << res << endl; // 输出结果
return 0; // 程序正常结束
}