题目描述
难度分:1568
输入n(1≤n≤2×105)和n行,每行三个数x,y,v,表示一个二维坐标点(x,y)和这个坐标点上的数字v (1≤x,y,v≤109)。
不在输入中的坐标点上的数字均为0。
请你选择一个坐标点(X,Y),累加所有横坐标为X的坐标点上的数字,以及所有纵坐标为Y的坐标点上的数字。
(X,Y)上的数字只累加一次。
(X,Y)不一定要在输入中。
输出累加值的最大值。
输入样例1
4
1 1 2
1 2 9
2 1 8
3 2 3
输出样例1
20
输入样例2
1
1 1000000000 1
输出样例2
1
输入样例3
15
158260522 877914575 602436426
24979445 861648772 623690081
433933447 476190629 262703497
211047202 971407775 628894325
731963982 822804784 450968417
430302156 982631932 161735902
880895728 923078537 707723857
189330739 910286918 802329211
404539679 303238506 317063340
492686568 773361868 125660016
650287940 839296263 462224593
492601449 384836991 191890310
576823355 782177068 404011431
818008580 954291757 160449218
155374934 840594328 164163676
输出样例3
1510053068
算法
贪心
显然我们需要尽可能选择“行累加和”与“列累加和”大的位置,可以先用两个哈希表rowsum、colsum构建映射“行 → 累加和”,“列 → 累加和”。接下来遍历一遍哈希表,把键值对取出来,构建两个二元组数组“(累加和,行)”row、“(累加和,列)”col,并对它们两个按照第一个关键字降序排列。
接下来双重循环遍历row和col两个数组,对于给定的行号r:
- 如果遍历到的列号c满足坐标(r,c)处是0(即不在所给的n个点中),由于行列都是按照累加和递增的顺序来遍历的,rowsum[r]+colsum[c]就是行为r时的最大累加和,维护最大值后退出对列的循环。
- 如果遍历到的列号c满足坐标(r,c)处不是0,就用rowsum[r]+colsum[c]−grid[r][c]维护当前最大值。
复杂度分析
时间复杂度
这个做法感觉最差的时间复杂度是O(n2),但是双重循环中只有两种情况:(1)网格(r,c)是题目给的点,上面是非零值;(2)是网格(r,c)上面是0,不是题目给的点。因为总共只有n个点,碰到第一种情况最多就是n次,而除此之外只要碰到过一次第二种情况就会更新最大值,从而退出内层对列的循环。因此,根据势能上的分析,更新最大值这一块的时间复杂度应该还是O(n),此时整个算法的瓶颈就在于排序,时间复杂度为O(nlog2n)。
空间复杂度
开辟了三个哈希表,一个用于存储行累加和,一个用于存储列累加和,最后一个用于存储n个网格点的数值,三个表都是O(n)级别的。因此,算法整体的额外空间复杂度为O(n)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <unordered_map>
using namespace std;
typedef long long LL;
int n;
struct pairHash {
template<typename T, typename U>
size_t operator()(const pair<T, U> &p) const {
return hash<T>()(p.first) ^ hash<U>()(p.second);
}
template<typename T, typename U>
bool operator()(const pair<T, U> &p1, const pair<T, U> &p2) const {
return p1.first == p2.first && p1.second == p2.second;
}
};
int main() {
scanf("%d", &n);
unordered_map<pair<int, int>, int, pairHash, pairHash> mp;
unordered_map<int, LL> rowsum, colsum;
for(int i = 1; i <= n; i++) {
int r, c, x;
scanf("%d%d%d", &r, &c, &x);
rowsum[r] += x;
colsum[c] += x;
mp[{r, c}] = x;
}
vector<pair<LL, int>> row, col;
for(auto&[r, val]: rowsum) row.push_back({val, r});
for(auto&[c, val]: colsum) col.push_back({val, c});
sort(row.begin(), row.end());
sort(col.begin(), col.end());
reverse(row.begin(), row.end());
reverse(col.begin(), col.end());
LL ans = 0;
for(auto&rtup: row) {
int r = rtup.second;
LL s1 = rtup.first;
for(auto&ctup: col) {
int c = ctup.second;
LL s2 = ctup.first;
if(mp.find({r, c}) == mp.end()) {
ans = max(ans, s1 + s2);
break;
}else {
ans = max(ans, s1 + s2 - mp[{r, c}]);
}
}
}
printf("%lld\n", ans);
return 0;
}