题目描述
难度分:1506
输入n(2≤n≤2×105)和长为n的数组a(1≤a[i]≤106)。下标从1开始。
定义f(i,j)=⌊max(a[i],a[j])min(a[i],a[j])⌋。
输出所有f(i,j)之和,其中1≤i<j≤n。
输入样例1
3
3 1 4
输出样例1
8
输入样例2
6
2 7 1 8 2 8
输出样例2
53
输入样例3
12
3 31 314 3141 31415 314159 2 27 271 2718 27182 271828
输出样例3
592622
算法
整除分块
思路比较容易想到,先让数组a排序然后逐个将每个元素作为最小值会更好考虑。而为了使用m+⌊m2⌋+⌊m3⌋+…+⌊mm⌋~
lnm这个结论,先把a中的数组放入一个vals数组中排序去重,然后遍历vals按照数值去考虑,否则遍历原数组a一旦有大量重复的小值,这个结论就会失效,重新退化为O(m)的时间复杂度。
对于vals中的一个元素x,枚举倍数t(满足tx≤maxi∈[1,n]a[i])。原数组中所有在区间[tx,(t+1)x)里面的数除以x向下取整就能得到t,因此对答案的贡献为counter[x]×t×(s[(t+1)x−1]−s[tx−1])。其中counter[num]表示数值num在原数组中的出现次数,s[num]表示原数组中≤num的数值有多少个。因此,还需O(n)预处理出counter,O(m)预处理出s数组,其中m为a的最大值。
但其实这里还没有做完,因为这样会使得t=1的情况被重复计算。a[i]会跟前面与自己相等的数计算贡献,也会与后面与自己相等的数计算贡献。所以遍历a进行去重,当遍历到a[i]的时候让答案减去counter[a[i]],此时counter[a[i]]表示前缀[1,i]中a[i]的频数。这样就能去掉a[i]与自己,以及前面与自己相等的数进行配对所产生的贡献。
复杂度分析
时间复杂度
对数组a排序的时间复杂度为O(nlog2n),对其去重的时间复杂度为O(n)(即对vals排序去重)。预处理a值域的前缀和数组s的时间复杂度为O(m),m表示a数组的最大值。对每个vals求答案,时间复杂度为O(m+⌊m2⌋+⌊m3⌋+…+⌊mm⌋),大约是O(logm),因此把所有vals[i]算完时间复杂度为O(nlogm)。最后还需要遍历一遍数组a,对重复计算的部分去重,时间复杂度为O(n)。
综上,整个算法的时间复杂度为O(n(log2n+logm))。
空间复杂度
vals的空间消耗是线性的,为O(n)。a的频数表counter也是线性的,为O(n)。前缀和数组s的空间复杂度与数组a的值域相关,为O(m)。因此,整个算法的额外空间复杂度为O(n+m)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010, M = 2000010;
int n, a[N], s[M];
struct custom_hash {
static uint64_t splitmix64(uint64_t x) {
x += 0x9e3779b97f4a7c15;
x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
return x ^ (x >> 31);
}
size_t operator()(uint64_t x) const {
static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
return splitmix64(x + FIXED_RANDOM);
}
};
int main() {
scanf("%d", &n);
vector<int> vals;
unordered_map<int, int, custom_hash> counter;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
counter[a[i]]++;
vals.push_back(a[i]);
}
sort(vals.begin(), vals.end());
vals.erase(unique(vals.begin(), vals.end()), vals.end());
sort(a + 1, a + n + 1);
for(int i = 0; i <= 2*a[i]; i++) {
s[i] = 0;
}
for(int i = 1; i <= 2*a[n]; i++) {
s[i] = s[i - 1] + (counter.count(i)? counter[i]: 0);
}
LL ans = 0;
for(auto&[num, freq]: counter) {
for(LL t = 1; t*num <= a[n]; t++) {
// [t*num,(t+1)*num)中有多少个数,这些数除以num都得t
ans += (s[(t + 1)*num - 1] - s[t*num - 1]) * t * freq;
}
}
counter.clear();
for(int i = 1; i <= n; i++) {
counter[a[i]]++;
ans -= counter[a[i]];
}
printf("%lld\n", ans);
return 0;
}