题目描述
难度分:1800
输入n、m(m<n≤2×106,1≤m≤5000)和长为n的数组r(−m≤r[i]≤m),其中恰好有m个0。
一开始x=y=0。从左到右遍历r:
- 如果r[i]=0,你可以把x加一,或者把y加一。
- 如果r[i]<0且x≥|r[i]|,那么得1分。
- 如果r[i]>0且y≥|r[i]|,那么得1分。
输出最大总得分。
输入样例1
10 5
0 1 0 2 0 -3 0 -4 0 -5
输出样例1
3
输入样例2
3 1
1 -1 0
输出样例2
0
输入样例3
9 3
0 0 1 0 2 -3 -2 -2 1
输出样例3
4
算法
动态规划
这个题乍一看没什么思路,因为这个n太大了,但是m≤5000是比较小的,感觉就应该从m入手来设计算法。而注意到只有r[i]=0的时候才能对x或y自增,所以x或y最终可以变得的大小也不会很大,隐隐感觉要用DP
来做。
状态定义
dp[i][x]表示当前考虑第i个0(从第0个0开始),且前面x的值大小为x的情况下,考虑完后面的所有0能够得到的最大得分。在这个定义下,答案就是dp[0][0],从第0个0开始考虑,初始的x=0。
状态转移
而到了第i个0时,前面已经经过了i个0,所以从前位置的x值也可以推测出y=i−x。这时候有两种策略,要么当前对x自增,要么当前对y自增,状态转移方程分别为dp[i][x]=cnt1+dp[i+1][x+1],dp[i][x]=cnt2+dp[i+1][x],两种情况选较大值转移。其中cnt1=get1(i,i+1,x+1)+get2(i,i+1,y),cnt2=get1(i,i+1,x)+get2(i,i+1,y+1)。get1(i,i+1,val)计算的是第i个0和第i+1个0之间有多少个小于0的r值绝对值≤val,get2(i,i+1,val)计算的是第i个0和第i+1个0之间有多少个大于0的r值≤val。
我们可以先把正数都保存在二元组数组ypos中,存(r[i],i);把负数保存在二元组数组xpos中,存(−r[i],i)。提前将位于两个相邻0之间的段按照第一关键字排序,这样在状态转移的时候就可以通过二分来加速了。
复杂度分析
时间复杂度
状态数量是O(m2),单次转移的时间复杂度为O(log2n)。对ypos和xpos分段排序,最终也相当于对O(n)规模的数组排序,时间复杂度为O(nlog2n)。因此,整个算法的时间复杂度为O(nlog2n+m2log2n)。
空间复杂度
DP
数组的空间复杂度为O(m2)。为了加速转移,预处理出两个O(m2)规模的二元组数组xmp和ymp,xmp[i][i+1]表示第i个0和第i+1个0之间的负r值在xpos数组上是哪个子区间(low,high),ymp[i][i+1]表示第i个0和第i+1个0之间的正r值在ypos数组上是哪个子区间(low,high)。因此,整个算法的额外空间复杂度为O(m2)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 2000010, M = 5001;
int n, m, r[N], dp[M][M];
array<int, 2> xmp[M][M], ymp[M][M];
int main() {
scanf("%d%d", &n, &m);
vector<array<int, 2>> xpos, ypos;
vector<int> zpos;
for(int i = 1; i <= n; i++) {
scanf("%d", &r[i]);
if(r[i] < 0) {
xpos.push_back({-r[i], i});
}else if(r[i] > 0) {
ypos.push_back({r[i], i});
}else {
zpos.push_back(i);
}
}
zpos.push_back(n + 1);
for(int i = 1; i < zpos.size(); i++) {
int cur = zpos[i - 1], nxt = zpos[i];
int l = 0, r = xpos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(xpos[mid][1] > cur) {
index = mid;
r = mid - 1;
}else {
l = mid + 1;
}
}
int low = index;
l = 0, r = xpos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(xpos[mid][1] < nxt) {
index = mid;
l = mid + 1;
}else {
r = mid - 1;
}
}
int high = index;
xmp[i - 1][i] = {low, high};
if(low != -1 && high != -1) {
sort(xpos.begin() + low, xpos.begin() + high + 1);
}
}
for(int i = 1; i < zpos.size(); i++) {
int cur = zpos[i - 1], nxt = zpos[i];
int l = 0, r = ypos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(ypos[mid][1] > cur) {
index = mid;
r = mid - 1;
}else {
l = mid + 1;
}
}
int low = index;
l = 0, r = ypos.size() - 1, index = -1;
while(l <= r) {
int mid = l + r >> 1;
if(ypos[mid][1] < nxt) {
index = mid;
l = mid + 1;
}else {
r = mid - 1;
}
}
int high = index;
ymp[i - 1][i] = {low, high};
if(low != -1 && high != -1) {
sort(ypos.begin() + low, ypos.begin() + high + 1);
}
}
function<int(int, int, int, int)> get = [&](int cur, int nxt, int x, int flag) {
auto& pir = flag? xmp[cur][nxt]: ymp[cur][nxt];
int l = pir[0], r = pir[1], index = r + 1;
if(l == -1 || r == -1) return 0;
while(l <= r) {
int mid = l + r >> 1;
if((flag? xpos[mid][0]: ypos[mid][0]) > x) {
index = mid;
r = mid - 1;
}else {
l = mid + 1;
}
}
return index - pir[0];
};
for(int x = 0; x <= m; x++) {
dp[m][x] = 0;
}
for(int i = zpos.size() - 2; i >= 0; i--) {
int cur = zpos[i], nxt = zpos[i + 1];
for(int x = 0; x <= i; x++) {
int y = i - x;
int cnt1 = get(i, i + 1, x + 1, 1) + get(i, i + 1, y, 0);
dp[i][x] = max(dp[i][x], cnt1 + dp[i + 1][x + 1]);
int cnt2 = get(i, i + 1, x, 1) + get(i, i + 1, y + 1, 0);
dp[i][x] = max(dp[i][x], cnt2 + dp[i + 1][x]);
}
}
printf("%d\n", dp[0][0]);
return 0;
}