题目描述
难度:[绿]普及+/提高
输入n、k(1≤k≤n≤105)和长为n的数组a(0≤a[i]≤109)。
从a中选择一些数,但这些数中,不能有超过k个数的下标是连续的,即下标i,i+1,i+2,…,i+k不能都选。
输出你选的数的最大和。
输入样例
5 2
1
2
3
4
5
输出样例
12
算法
线段树优化DP
这个题直接正面求解并不是很好做,但是逆向思维求“所有不选的数累加和的最小值”就好做很多了。
状态定义
dp[i]表示选择a[i]的情况下,从前缀[1,i]中能选出来的数最小累加和。此时就要求任意两个被选的数,它们中间的间隔不能超过k。在这个定义下,答案就应该是Σni=1a[i]−minni=n−kdp[i]。第二项表示最后一个数选在[n−k,n]区间,就能保证后面再也凑不出长度至少为k的连续不选段。
状态转移
初始化dp[0]=0,这是空数组的情况。如果选择a[i],那么上一个数就应该在j∈[i−k−1,i−1]区间,如果j<i−k−1,那么j和i之间就有超过k个数没被选。因此,状态转移方程为dp[i]=minj∈[max(0,i−k−1),i−1]dp[j]+a[i]。发现这是一段动态的区间求最值操作,用线段树来存DP
数组就可以做到O(log2n)的单次转移。
复杂度分析
时间复杂度
状态数量为O(n),单次转移是O(log2n)的,因此整个算法的时间复杂度为O(nlog2n)。
空间复杂度
除了输入的数组a,主要的空间消耗就在于线段树的DP
数组。空间消耗和状态数量是一个级别的,因此整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 100010;
const LL INF = 0x3f3f3f3f3f3f3f3f;
int n, k, a[N];
LL dp[N];
class SegmentTree {
public:
struct Info {
int l, r;
LL v;
Info() {}
Info(int left, int right, LL val): l(left), r(right), v(val) {}
} seg[N<<2];
explicit SegmentTree() {}
void build(int u, int l, int r) {
if(l == r) {
seg[u] = Info(l, r, INF);
}else {
int mid = l + r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid + 1, r);
pushup(u);
}
}
void modify(int pos, LL val) {
++pos;
modify(1, pos, val);
}
Info query(int l, int r) {
if(l > r) return Info(0, 0, 0);
++l, ++r;
return query(1, l, r);
}
private:
void modify(int u, int pos, LL val) {
if(seg[u].l == pos && seg[u].r == pos) {
seg[u] = Info(pos, pos, val);
}else {
int mid = seg[u].l + seg[u].r >> 1;
if(pos <= mid) modify(u<<1, pos, val);
else modify(u<<1|1, pos, val);
pushup(u);
}
}
Info query(int u, int l, int r) {
if(l <= seg[u].l && seg[u].r <= r) return seg[u];
int mid = seg[u].l + seg[u].r >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
void pushup(int u) {
seg[u] = merge(seg[u<<1], seg[u<<1|1]);
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l;
info.r = rchild.r;
info.v = min(lchild.v, rchild.v);
return info;
}
};
int main() {
scanf("%d%d", &n, &k);
LL tot = 0;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
tot += a[i];
}
memset(dp, 0x3f, sizeof dp);
dp[0] = 0;
SegmentTree tr;
tr.build(1, 1, n + 1);
tr.modify(0, 0);
for(int i = 1; i <= n; i++) {
LL dpj = tr.query(max(0, i - k - 1), i - 1).v;
dp[i] = min(dpj + a[i], dp[i]);
tr.modify(i, dp[i]);
}
LL ans = 0;
for(int i = n - k; i <= n; i++) {
ans = max(ans, tot - dp[i]);
}
printf("%lld\n", ans);
return 0;
}