题目描述
难度分:1700
输入T(≤104)表示T组数据。所有数据的n之和≤3×105。
每组数据输入n(1≤n≤3×105),k(0≤k≤10)和长为n的数组a(1≤a[i]≤109)。
如下操作,至多执行k次:
选择a中两个相邻元素a[i]和a[i+1],更新a[i]=a[i+1],或者更新 a[i+1]=a[i]。
输出Σi∈[1,n]a[i]的最小值。
输入样例
4
3 1
3 1 2
1 3
5
4 2
2 2 1 3
6 3
4 1 2 2 4 3
输出样例
4
5
5
10
算法
动态规划
有三种情况是可以直接过滤的,n=1、k=0,以及数组最小值等于最大值,这三种情况直接打印原数组的和就好。对于一般情况,刚开始想的是贪心,但是WA
了一个很好的case
4 2
2 1 2 3
可以看到将a[3]变成a[2]和将a[4]变成a[3]的收益是一样的,如果先将a[4]变成a[3],数组就会变为2 1 2 2
,再将1旁边的一个2变成1,数组的总和就为6。但是如果先将a[3]变成a[2],数组就会变为2 1 1 3
,再把a[4]变成a[3]就能得到总和为5的数组2 1 1 1
。因此,操作顺序是会影响结果的,不能直接贪心,因此往DP
上考虑。由于k≤10特别小,估计在状态设计时状态的数量就是O(nk)。
状态定义
dp[i][rest]表示在还剩下rest次操作的情况下,后缀[i,n]总和减小量的最大值。因此答案就应该是Σi∈[1,n]a[i]−dp[1][k]。
状态转移
当[1,i)上都考虑完后,从i开始考虑,首先可以直接略过i位置,不对a[i]进行任何改动,状态转移方程为dp[i][rest]=dp[i+1][rest]。
如果要在[i,j]上进行改动的话,那最好的方式就是将[i,j]上的所有数变成最小值minidx∈[i,j]a[idx]。因此状态转移方程为dp[i][rest]=dp[i+x+1][rest−(x+1−mincnt)]+Σj∈[i,i+x]a[j]−minval×(x+1),其中minval=inidx∈[i,i+x]a[idx],mincnt是子数组[i,i+x]上等于minval的元素个数,在这个转移方程中x=0就是忽略i的情况,在保证rest≥x+1−mincnt的情况下选最大值转移即可。
注意
感觉在实现的时候有点问题,单次转移时不应该逐个遍历后面的元素,应该不断跳到下一个值不同的元素。
复杂度分析
时间复杂度
状态数量是O(nk),转移的时候要尝试消耗当前剩余的次数rest,它是O(k)级别的,因此时间复杂度为O(nk2)。
空间复杂度
空间消耗主要在于DP
矩阵,即状态数量的量级。因此,算法整体的额外空间复杂度为O(nk)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 300010, INF = 0x3f3f3f3f;
int t, n, k, a[N];
LL dp[N][11];
LL dfs(int i, int rest) {
if(i > n) return 0LL;
LL &v = dp[i][rest];
if(v != -1) return v;
LL res = 0LL, sum = 0;
int minval = INF, mincnt = 0;
for(int x = 0; i + x <= n; x++) {
sum += a[i + x];
if(a[i + x] < minval) {
minval = a[i + x];
mincnt = 1;
}else if(a[i + x] == minval) {
mincnt++;
}
LL len = x + 1;
if(len - mincnt > rest) break;
res = max(res, sum - len*minval + dfs(i + len, rest - (len - mincnt)));
}
return v = res;
}
void solve() {
LL tot = accumulate(a + 1, a + n + 1, 0LL);
int minval = *min_element(a + 1, a + n + 1);
int maxval = *max_element(a + 1, a + n + 1);
if(n == 1 || k == 0 || minval == maxval) {
printf("%lld\n", tot);
return;
}
printf("%lld\n", tot - dfs(1, k));
}
int main() {
scanf("%d", &t);
for(int cases = 1; cases <= t; cases++) {
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
for(int j = 0; j <= k; j++) {
dp[i][j] = -1;
}
}
solve();
}
return 0;
}