K-单调问题
解题思路
本题要求用最小代价让整个序列变成 $k$ 个单调序列。这一题是 数字序列 的升级版,数字序列 求的是让一个序列变成单调序列的最小代价。这里可以用类似的做法来做。
由于要变成 $k$ 个序列,因此用 $dp$ 来求,首先需要预处理出将每一段区间变成单调需要的最小代价 $cost(i, j)$。本题的数据范围较小,直接枚举左端点 $l$,每次将 $l$ 开头的区间统一预处理,这里用 数字序列 的做法来求区间的最小代价,预处理每一段区间的时间复杂度是 $O(nlogn)$ 的,一共枚举 $n$ 个左端点,所以整个预处理的时间复杂度就是 $O(n^2logn)$ 的,不会超时。
在 数字序列 中我们只求了单调上升的最小代价,这里可以用一个小技巧,我们求完单调上升后,将整个序列取反,再求一次单调上升,而取反后的单调上升实际上就是取反之前的单调下降,这样就能用同一份代码求单调上升和单调下降的最小代价,再取一个最小值即可。
这里对于每个区间,我们能用 数字序列 构造出一个单调序列,使得原序列和单调序列的各项之差的绝对值之和最小,而各项之差的绝对值就是操作的最小代价,但是我们在 数字序列 中求出构造序列之后并不能快速得出最小代价是多少,还需要枚举一下去计算,这里我们要想快速计算出来,可以对于每个区间,我们需要维护 $sum$:整个区间的和,$cnt$:整个区间的数的个数,$sum’$:较小一半数的和,$cnt’$:较小一半数的个数,$root$:对应左偏树的根节点。这里我们可以用 $root$ 直接得到中位数 $mid$,然后我们就可以通过以上信息 $O(1)$ 计算出这个区间的最小代价为 $mid \times cnt’ - sum’ + sum - sum’ - (cnt - cnt’) \times mid$,这样每次求出单调序列时方便我们计算每个区间的最小代价。
预处理完代价就可以用 $dp$ 来做了,设 $f(i, j)$ 表示将 $a_1 \sim a_i$ 分成 $j$ 个单调区间的最小代价。决策就是枚举第 $j$ 个区间的长度进行转移,得出状态转移方程
$$ f(i, j) = min \lbrace f(i - k, j - 1) + cost(i - k + 1, i) \rbrace~~k \in [1, i] $$
C++ 代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1010;
int n, m;
//cost[i][j] 表示将 (i, j) 变成单调区间的最小代价
//f[i][j] 表示将 a[1] ~ a[i] 分成 j 个单调区间的最小代价
int f[N][11], cost[N][N];
int a[N]; //原序列
int v[N], dist[N], l[N], r[N], idx; //左偏树
struct Segment
{
int root; //根节点
int tot_sum, tot_cnt; //总的和、总的大小
int tree_sum, tree_cnt; //较小一半的和、较小一半的大小
int get_cost() //计算当前区间的最小代价
{
int mid = v[root];
return mid * tree_cnt - tree_sum + tot_sum - tree_sum - mid * (tot_cnt - tree_cnt);
}
}stk[N]; //存储所有区间
int tt;
int merge(int x, int y) //将 x 和 y 所在的左偏树合并,并返回根节点
{
if(!x || !y) return x + y;
if(v[x] < v[y]) swap(x, y);
r[x] = merge(r[x], y);
if(dist[r[x]] > dist[l[x]]) swap(r[x], l[x]);
dist[x] = dist[r[x]] + 1;
return x;
}
int pop(int x) //将以 x 为根的左偏树的根节点删除
{
return merge(l[x], r[x]);
}
void get_cost(int u) //计算以 u 为左端点的所有区间的最小代价
{
tt = 0;
int res = 0; //记录当前的代价
for(int i = u; i <= n; i++)
{
auto cur = Segment({i, v[i], 1, v[i], 1});
l[i] = r[i] = 0, dist[i] = 1;
while(tt && v[stk[tt].root] > v[cur.root])
{
res -= stk[tt].get_cost(); //合并前先将记录的代价删掉
bool is_pop = cur.tot_cnt % 2 && stk[tt].tot_cnt % 2; //记录合并后是否需要删除堆顶元素
//更新合并后的信息
cur.root = merge(cur.root, stk[tt].root);
cur.tot_cnt += stk[tt].tot_cnt;
cur.tot_sum += stk[tt].tot_sum;
cur.tree_cnt += stk[tt].tree_cnt;
cur.tree_sum += stk[tt].tree_sum;
if(is_pop) //如果需要删除堆顶元素
{
cur.tree_cnt--;
cur.tree_sum -= v[cur.root];
cur.root = pop(cur.root);
}
tt--;
}
stk[++tt] = cur;
res += cur.get_cost();
cost[u][i] = min(cost[u][i], res);
}
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
//预处理 cost
memset(cost, 0x3f, sizeof cost);
for(int i = 1; i <= n; i++) v[i] = a[i] - i; //将求严格单调递增变成求非严格单调递增
for(int i = 1; i <= n; i++) get_cost(i); //计算以 i 为左端点的所有递增区间的代价
for(int i = 1; i <= n; i++) v[i] = -a[i] - i; //将所有数取反,计算单调递减
for(int i = 1; i <= n; i++) get_cost(i); //计算以 i 为左端点的所有递减区间的代价
//初始化
memset(f, 0x3f, sizeof f);
f[0][0] = 0;
for(int i = 1; i <= n; i++)
for(int j = 1; j <= m; j++)
for(int k = 1; k <= i; k++)
f[i][j] = min(f[i][j], f[i - k][j - 1] + cost[i - k + 1][i]);
printf("%d\n", f[n][m]);
return 0;
}
bixu zhichi