D. Optimal Partition
题意:将一段区间分成若干段,分成的段分三种情况,贡献不一样,求怎么样切使得数组的总贡献最大。
思路: (DP + 权值线线段树优化)
可以发现非常容易推出状态转移方程
f[i]:前i个位置中且在第i个位置切一刀的集合
属性:max
$$ if(s[i] - s[j] >= 0) \\ f[i] = max(f[j] + i - j) $$
$$ if(s[i] - s[j] < 0) \\ f[i] = max(f[j] - (i - j))$$
可以发现为O(n ^ 2)的DP, 铁TLE
化简式子得到
$(s[i] > s[j]) f[i] - i = max(f[j] - j) $
$(s[i] == s[j]) f[i] = max(f[j]) $
$(s[i] < s[j]) f[i] + i = max(f[j] + j) $
可以发现只有s[j] < s[i]的时候才会有 第一个方程,当前条件只由前缀和限制(2 3等价)
发现可以利用权值来构建一颗线段树来维护三个属性, x0为f[i] - i, x1为f[i], x2为f[i] + i
涉及到了求区间最值和单点修改。
权值线段树,实际上就是按照值域进行划分区间的普通线段树,所以一般需要进行离散化。
modify(1, pos[0], pos[0], 0, 0);//当前位置相当于DP的初始化 这个DP相当于由线段树来维护上一步最优解
除此之外还要将0也放入离散化数组,因为前缀和有可能出现负值
code
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <queue>
#include <deque>
#include <stack>
#include <bitset>
#include <unordered_map>
#define IOS ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define eb push_back()
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int maxn = 5e5 + 10;
const ll MAXN = -2e18;
typedef pair <int, int> PII;
ll a[maxn], sum[maxn], f[maxn];
int pos[maxn];
struct note{
int l;
int r;
ll x0;
ll x1;
ll x2;
int mid()
{
return (l + r) >> 1;
}
}tre[maxn << 2];
void pushup(int rt)
{
tre[rt].x0 = max(tre[rt * 2].x0, tre[rt * 2 + 1].x0);
tre[rt].x1 = max(tre[rt * 2].x1, tre[rt * 2 + 1].x1);
tre[rt].x2 = max(tre[rt * 2].x2, tre[rt * 2 + 1].x2);
}
void build(int rt, int l, int r)
{
tre[rt] = {l, r};
if(tre[rt].l == tre[rt].r)
{
tre[rt].x0 = tre[rt].x1 = tre[rt].x2 = MAXN;
return ;
}
int mid = tre[rt].mid();
build(rt * 2, l, mid), build(rt * 2 + 1, mid + 1, r);
pushup(rt);
}
void modify(int rt, int l , int r, ll x, ll t)
{
if(tre[rt].l == l && tre[rt].r == r)
{
tre[rt].x0 = max(tre[rt].x0, x - t);
tre[rt].x1 = max(tre[rt].x1, x);
tre[rt].x2 = max(tre[rt].x2, x + t);
return ;
}
int mid = tre[rt].mid();
if(l <= mid) modify(rt * 2, l, r, x, t);
else modify(rt * 2 + 1, l, r, x, t);
pushup(rt);
}
note query(int rt, int l, int r)
{
if(tre[rt].l == l && tre[rt].r == r)
{
return tre[rt];
}
int mid = tre[rt].mid();
if(r <= mid) return query(rt * 2, l, r);
else if(l > mid) return query(rt * 2 + 1, l, r);
else
{
note res, temp1, temp2;
temp1 = query(rt * 2, l, mid), temp2 = query(rt * 2 + 1, mid + 1, r);
res.x0 = max(temp1.x0, temp2.x0);
res.x1 = max(temp1.x1, temp2.x1);
res.x2 = max(temp1.x2, temp2.x2);
return res;
}
}
int main()
{
int T;
scanf("%d", &T);
while(T --)
{
int n;
scanf("%d", &n);
for(int i = 1 ; i <= n ; i ++)
scanf("%lld", &a[i]), sum[i] = sum[i - 1] + a[i];
vector <ll> alls;
for(int i = 0 ; i <= n ; i ++)
alls.push_back(sum[i]);
sort(alls.begin(), alls.end()); //哈希掉
alls.erase(unique(alls.begin(), alls.end()), alls.end());
int r = alls.size(); // [0--- r - 1]
build(1, 1, r);
for(int i = 0 ; i <= n ; i ++)
{
pos[i] = lower_bound(alls.begin(), alls.end(), sum[i]) - alls.begin() + 1;
}
for(int i = 0 ; i <= n ; i ++)
f[i] = MAXN;
f[0] = 0;
modify(1, pos[0], pos[0], 0, 0);//当前位置相当于DP的初始化
for(int i = 1 ; i <= n ; i ++)
{
if(pos[i] - 1 >= 1)
f[i] = max(f[i], query(1, 1, pos[i] - 1).x0 + i * 1ll);
f[i] = max(f[i], query(1, pos[i], pos[i]).x1);
if(pos[i] + 1 <= r)
f[i] = max(f[i], query(1, pos[i] + 1, r).x2 - i * 1ll);
modify(1, pos[i], pos[i], f[i], i * 1ll);
}
printf("%lld\n", f[n]);
}
return 0;
}