题目描述
难度分:2265
输入n,m(1≤n,m≤300)和n行m列的矩阵A,元素范围[1,300]。
对于A的一个非空子矩阵R,定义f(R)=子矩阵的元素和×子矩阵的最小值。
输出f(R)的最大值。
输入样例1
3 3
5 4 3
4 3 2
3 2 1
输出样例1
48
输入样例2
4 5
3 1 4 1 5
9 2 6 5 3
5 8 9 7 9
3 2 3 8 4
输出样例2
231
输入样例3
6 6
1 300 300 300 300 300
300 1 300 300 300 300
300 300 1 300 300 300
300 300 300 1 300 300
300 300 300 300 1 300
300 300 300 300 300 1
输出样例3
810000
算法
前缀和+单调栈
预处理
先预处理出矩阵rminn×m×m,和矩阵sn×m。其中rmin[i][l][r]表示矩阵A的第i行中,第l列到第r列的最小值。比较简单,每行用动态规划即可求得,状态转移方程为
rmin[i][l][r]=min(rmin[i][l][r−1],A[i][r])
其中s[i][j]表示第i行前j列元素的和,主要用于求每行任意子数组的和。
枚举+单调栈
接下来枚举矩阵R的左边界l,以及宽度w。固定l和w之后,遍历所有行,构造两个数组csum1×n和cmin1×n。其中csum[i]表示矩阵A第i行,子数组[l,l+w−1]的累加和,cmin[i]表示矩阵A第i行,子数组[l,l+w−1]的最小值,即rmin[i][l][l+w−1]。
有了cmin,就可以枚举子矩阵中的最小元素cmin[i]。由于矩阵A中的所有元素都是正数,要想函数f(R)的值最大,在最小值确定的情况下只需要矩阵的累加和最大即可,因此利用单调栈计算出每个cmin[i]在行方向上能够保持最小值地位的最大范围。left[i]表示cmin[i]左边离自己最近的,且比自己小的数的行号;right[i]表示cmin[i]右边离自己最近的,且比自己小的数的行号。此时,以cmin[i]为最小值且f(R)最大的矩阵在行上的跨度就是(left[i],right[i]),f(R)=cmin[i]×Σright[i]−1k=left[i]+1csum[k],其中的求和操作可以利用csum数组的前缀和数组快速求得。
对于每个左边界l以及子矩阵宽度w,按照以上的单调栈算法维护所有f(R)的最大值即可。
复杂度分析
时间复杂度
预处理:每行的区间最小值rmin矩阵,时间复杂度为O(nm2);每行的前缀和矩阵s,时间复杂度为O(nm)。
算法:外层循环子矩阵宽度以及子矩阵的左端点,时间复杂度为O(m2);内层得到每行在列区间[l,r]的最小值数组cmin和区间和数组csum,时间复杂度为O(n);单调栈求cmin数组中每个元素cmin[i]左右两边最近的更小元素,时间复杂度为O(n);最后计算宽度为r−l+1的所有子矩阵R中f(R)的最大值,时间复杂度为O(n)。内层循环是有限几次与数据量无关的O(n)流程。
综上,算法整体的时间复杂度为O(nm2)。
空间复杂度
经过以上时间复杂度的分析,空间的瓶颈就在于三维矩阵rmin,因此额外空间复杂度为O(nm2)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <stack>
using namespace std;
typedef long long LL;
const int N = 301;
int a[N][N];
LL s[N][N];
int n, m, rmin[N][N][N];
int main() {
scanf("%d%d", &n, &m);
memset(rmin, 0x3f, sizeof rmin);
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= m; j++) {
scanf("%d", &a[i][j]);
}
for(int j = 1; j <= m; j++) {
rmin[i][j][j] = a[i][j];
for(int k = j + 1; k <= m; k++) {
rmin[i][j][k] = min(rmin[i][j][k - 1], a[i][k]);
}
}
}
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= m; j++) {
s[i][j] = s[i][j - 1] + a[i][j];
}
}
LL ans = 0;
for(int l = 1; l <= m; l++) {
for(int w = 1; l + w - 1 <= m; w++) {
vector<LL> cmin, csum;
for(int i = 1; i <= n; i++) {
int r = l + w - 1;
csum.push_back(s[i][r] - s[i][l - 1]);
cmin.push_back(rmin[i][l][r]);
}
for(int i = 1; i < n; i++) {
csum[i] += csum[i - 1];
}
vector<int> left(n, -1), right(n, n);
stack<int> stk;
for(int i = 0; i < n; i++) {
while(!stk.empty() && cmin[stk.top()] >= cmin[i]) {
right[stk.top()] = i;
stk.pop();
}
if(!stk.empty()) left[i] = stk.top();
stk.push(i);
}
for(int i = 0; i < n; i++) {
int lb = left[i], ub = right[i];
ans = max(ans, (csum[ub - 1] - (lb >= 0? csum[lb]: 0))*cmin[i]);
}
}
}
printf("%lld\n", ans);
return 0;
}