题目描述
难度:[绿]普及+/提高
输入n,L,R(1≤L≤R≤n≤2×105)和长为n+1的数组a(−103≤a[i]≤103),下标从0开始。保证a[0]=0。
有n+1个格子,编号从0到n。你从0号格子向右跳。如果你在格子i,可以跳到编号[i+L,i+R]中的任意格子。跳到格子i后,总得分增加a[i]。
如果跳出界(i>n),游戏结束。输出游戏结束后的最大总得分。
输入样例
5 2 3
0 12 3 11 7 -2
输出样例
11
算法
线段树优化DP
比较容易发现是DP
求解,i位置可以跳到[i+L,i+R]这个区间内的任何点。因此i位置的状态转移来源就是[i+L,i+R],也肯定是选最大值转移,这涉及到动态求区间最值,需要用线段树来维护DP
数组。而为了方便实用线段树,让数组的下标从1开始,整个游戏的目的是从1跳到>n+1。
状态定义
dp[i]表示从i跳到游戏结束能够获得的最大得分,在这个定义下,答案就是dp[1]。
状态转移
从后往前遍历i,计算每个dp[i],分为以下两种情况:
base case:i>n+1则有dp[i]=0,因为游戏已经结束了。
一般情况:否则i可以跳到[i+L,i+R],算上此时在i位置的得分a[i],状态转移方程为dp[i]=a[i]+maxj∈[i+L,i+R]dp[j],用线段树在O(log2n)的时间复杂度下把后面这个区间最值求出来即可。
复杂度分析
时间复杂度
状态数量是O(n),单次转移是O(log2n)的,因此整个算法的时间复杂度为O(nlog2n)。
空间复杂度
除了输入的数组a,空间消耗的瓶颈就在于线段树存储的dp
数组,空间复杂度仅和状态数量相关。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 400010;
int n, L, R, a[N>>1];
class SegmentTree {
public:
struct Info {
int l, r, v;
Info() {}
Info(int left, int right, int val): l(left), r(right), v(val) {}
} seg[N<<2];
explicit SegmentTree() {}
void build(int u, int l, int r) {
if(l == r) {
seg[u] = Info(l, r, 0);
}else {
int mid = l + r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid + 1, r);
pushup(u);
}
}
void modify(int pos, int val) {
modify(1, pos, val);
}
Info query(int l, int r) {
if(l > r) return Info(0, 0, 0);
return query(1, l, r);
}
private:
void modify(int u, int pos, int val) {
if(seg[u].l == pos && seg[u].r == pos) {
seg[u] = Info(pos, pos, val);
}else {
int mid = seg[u].l + seg[u].r >> 1;
if(pos <= mid) modify(u<<1, pos, val);
else modify(u<<1|1, pos, val);
pushup(u);
}
}
Info query(int u, int l, int r) {
if(l <= seg[u].l && seg[u].r <= r) return seg[u];
int mid = seg[u].l + seg[u].r >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
void pushup(int u) {
seg[u] = merge(seg[u<<1], seg[u<<1|1]);
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l;
info.r = rchild.r;
info.v = max(lchild.v, rchild.v);
return info;
}
};
int main() {
scanf("%d%d%d", &n, &L, &R);
for(int i = 1; i <= n + 1; i++) {
scanf("%d", &a[i]);
}
SegmentTree dp;
dp.build(1, 1, 2*n + 1);
int ans = 0;
for(int i = n + 1; i >= 1; i--) {
dp.modify(i, a[i] + dp.query(i + L, i + R).v);
}
printf("%d\n", dp.query(1, 1).v);
return 0;
}