题目描述
给定一个整数数组 A
,以及一个整数 target
作为目标值,返回满足 i < j < k
且 A[i] + A[j] + A[k] == target
的元组 i, j, k
的数量。
由于结果会非常大,请返回 结果除以 10^9 + 7
的余数。
样例
输入:A = [1,1,2,2,3,3,4,4,5,5], target = 8
输出:20
解释:
按值枚举(A[i],A[j],A[k]):
(1, 2, 5) 出现 8 次;
(1, 3, 4) 出现 8 次;
(2, 2, 4) 出现 2 次;
(2, 3, 3) 出现 2 次。
输入:A = [1,1,2,2,2,2], target = 5
输出:12
解释:
A[i] = 1,A[j] = A[k] = 2 出现 12 次:
我们从 [1,1] 中选择一个 1,有 2 种情况,
从 [2,2,2,2] 中选出两个 2,有 6 种情况。
限制
3 <= A.length <= 3000
0 <= A[i] <= 100
0 <= target <= 300
算法
(枚举数字组合) O(n+w2)
- 从小到大枚举 i 和 j,计算 k=target−i−j,满足 i≤j≤k。
- 若 i=j=k,则答案数加上 cnt[i]∗cnt[j−1]∗cnt[k−2]/6。
- 若 i=j<k,则答案数加上 cnt[i]∗cnt[j−1]/2∗cnt[k]。
- 若 i<j=k,则答案数加上 cnt[i]∗cnt[j]∗cnt[k−1]/2。
- 若 i<j<k,则答案数加上 cnt[i]∗cnt[j]∗cnt[k]。
- 其中,cnt[x] 是数字 x 在数组
A
中出现的次数。
时间复杂度
- 统计 cnt 数组需要 O(n) 的时间,枚举数字需要 O(w2) 的时间,故总时间复杂度为 O(n+w2)。
C++ 代码
#define LL long long
class Solution {
public:
int threeSumMulti(vector<int>& A, int target) {
const int mod = 1000000007;
vector<int> cnt(101, 0);
int n = A.size(), ans = 0;
for (int i = 0; i < n; i++)
cnt[A[i]]++;
for (int i = 0; i <= target; i++)
for (int j = i; j <= target - i; j++) {
int k = target - i - j, cur;
if (k < j)
break;
if (k > 100)
continue;
if (i == j && j == k)
cur = (LL)(cnt[i]) * (cnt[j] - 1) * (cnt[k] - 2) / 6 % mod;
else if (i == j && j != k)
cur = (LL)(cnt[i]) * (cnt[j] - 1) / 2 * cnt[k] % mod;
else if (i != j && j == k)
cur = (LL)(cnt[i]) * cnt[j] * (cnt[k] - 1) / 2 % mod;
else
cur = (LL)(cnt[i]) * cnt[j] * cnt[k] % mod;
ans = (ans + cur) % mod;
}
return ans;
}
};
cnt[i]∗cnt[j−1]∗cnt[k−2]/6中的cnt[k−2]/6 代表什么意思?
去重,要除以 3!