题目描述
难度分:2006
输入n(2≤n≤5×105)和长为n的数组a(1≤a[i]≤109)。
有n只史莱姆排成一行,从左到右第i只史莱姆的体积为a[i]。假设你是第k只史莱姆,每次操作,你可以吃掉左右相邻的体积严格小于你的一只史莱姆,并获得它的体积。把你可以达到的最大体积记为f[k]。
输出f[1],f[2],f[3],…,f[n]。
输入样例1
6
4 13 2 3 2 6
输出样例1
4 30 2 13 2 13
输入样例2
12
22 25 61 10 21 37 2 14 5 8 6 24
输出样例2
22 47 235 10 31 235 2 235 5 235 6 235
算法
单调栈+动态规划
值域上做动态规划,状态就是我们要求的f[k]。对于一个史莱姆k,它左右两边体积严格大于自己的史莱姆分别在l和r位置,这两个位置可以通过单调栈算法预处理出来。
- 如果a[k]=a[k−1]=a[k+1],那它就被左右邻居卡死了,一步也动不了,f[k]=a[k]。
- 否则k至少能吃掉一个邻居,那么k就能吃掉区间(l,r)中所有体积不超过自己的史莱姆。如果a[k]>max(a[l],a[r]),那么f[k]=max(f[l],f[r]);如果a[k]>min(a[l],a[r]),那么f[k]=min(f[l],f[r]);否则f[k]=Σr−1i=l+1a[i],可以预处理出一个a的前缀和数组s,快速计算这段区间和。
遍历a数组构建一个有序映射mp,它表示“数值→该数值出现的索引位置”。单调栈预处理出每个史莱姆左右两边比自己体积大的最近史莱姆位置。最后倒序遍历mp进行DP
,先计算体积大的史莱姆的答案,这样才方便转移。
复杂度分析
时间复杂度
遍历a预处理出前缀和数组s的时间复杂度为O(n);单调栈预处理出l和r的时间复杂度为O(n);由于mp的插入是O(log2n)的,因此预处理出它的时间复杂度为O(nlog2n)。最后倒序遍历mp在值域上DP
其实就是遍历O(n)个坐标,时间复杂度还是O(n)。因此,整个算法的时间复杂度为O(nlog2n)。
空间复杂度
前缀和数组s,DP
数组f,以及每个史莱姆左右两边比自己体积大的史莱姆位置l和r都需要O(n)的空间来存储。mp映射表也需要存O(n)个索引。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 500010;
int n, a[N], l[N], r[N];
LL s[N], f[N];
LL windowSum(int l, int r) {
return s[r] - s[l - 1];
}
int main() {
scanf("%d", &n);
map<int, vector<int>> mp;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
mp[a[i]].push_back(i);
s[i] = s[i - 1] + a[i];
l[i] = 0, r[i] = n + 1;
}
stack<int> stk;
for(int i = 1; i <= n; i++) {
while(stk.size() && a[stk.top()] < a[i]) {
r[stk.top()] = i;
stk.pop();
}
stk.push(i);
}
while(stk.size()) stk.pop();
for(int i = n; i >= 1; i--) {
while(stk.size() && a[stk.top()] < a[i]) {
l[stk.top()] = i;
stk.pop();
}
stk.push(i);
}
for(auto it = mp.rbegin(); it != mp.rend(); it++) {
for(int x: it->second) {
int low = l[x], high = r[x];
bool ok = false;
for(int y = x - 1; y <= x + 1; y += 2) {
if(1 <= y && y <= n) {
if(a[y] < a[x]) {
ok = true;
}
}
}
if(ok) {
LL sum = windowSum(low + 1, high - 1);
if(low == 0 && high == n + 1) {
f[x] = sum;
}else if(low == 0) {
if(sum > a[high]) {
f[x] = f[high];
}else {
f[x] = sum;
}
}else if(high == n + 1) {
if(sum > a[low]) {
f[x] = f[low];
}else {
f[x] = sum;
}
}else {
if(sum > max(a[low], a[high])) {
f[x] = max(f[low], f[high]);
}else if(sum > min(a[low], a[high])) {
f[x] = min(f[low], f[high]);
}else {
f[x] = sum;
}
}
}else {
f[x] = a[x];
}
}
}
for(int i = 1; i <= n; i++) {
printf("%lld ", f[i]);
}
return 0;
}