首先考虑如果没有 b 这个序列,只需要维护如下操作:
- ax=v 单点改。
- 求 [l,r] 区间前缀最大值之和,即 r∑i=lmax。
事实上是好做的。
在每个节点维护区间最大值、全歼前缀最大值之和,即 \max,suma。
定义 ask(u,v)
表示 u 节点区间之前的最大值为 v,求 u 内前缀最大值之和。
那么 pushup
就好写了:u.suma = ls.suma + ask(rs, ls.max)
。
至于 ask
函数,需要分类讨论:
- 若 v \geq ls.\max:则答案为 v \times ls.siz + ask(rs, v)。
- 若 v \lt ls.\max:答案为 ask(ls, v) + ask(rs, ls.\max) = ask(ls, v) + u.suma - ls.suma。
- 这样就成了单侧递归,保证复杂度。
其中 ask
函数复杂度 O(\log n),那么 pushup
也是 O(\log n),单点修改复杂度为 O(\log ^2 n)。
然后考虑加入 b 序列怎么做。
对于 a 序列的维护与上文同理。
考虑定义 update(u,v,d)
函数,表示 u 节点区间之前最大值为 v,将 u 区间内进行 b 序列操作,执行 d 次(因为可能多次操作,这里在下传标记)。
同样也要分类讨论:
- 首先是一些基本的 Corner Case:
- 若 v \geq u.\max,即覆盖全区间:
- 那么对于区间内所有 b_i \leftarrow b_i + v \times d。
- 若已经到叶子节点:
- 则 b_i \leftarrow b_i + \max(v,a_i) \times d。
- 若 v \geq u.\max,即覆盖全区间:
- 若 v \geq ls.\max:
- 对于左儿子所有 b_i \leftarrow b_i + v \times d,这只要打一个区间加常数标记即可维护。
- 对于右儿子不确定 v 作为前缀 \max 延续到哪里,因此要递归处理。
- 若 v \lt ls.\max:
- 对于左儿子不确定 v 作为前缀 \max 延续到哪里,因此要递归处理。
- 对于右儿子相当于要执行
update(rs, ls.max, d)
,但是直接递归下去复杂度爆表。 - 因此只能打标记,由于这种双侧都要递归的情况只会出现在节点的右儿子上,可以考虑在 u 上直接打标记。
- 即
flag2
表示 u 的右儿子需要执行update(rs, ls.max, flag2)
,在 pushdown 的时候再往右儿子下传即可。 - 这样的好处在于可以直接计算出
sumb
(答案):sumb \leftarrow sumb + (u.suma - l.suma) \times d, flag2 \leftarrow flag2 + d。
有一些细节:update 之后需要顺便返回新的 suma
,方便上层节点进行计算。
核心技巧是将双侧下传改为单侧下传。
另一侧通过打标记在 pushdown 的时候再执行下传操作,将 pushdown 的时间复杂度变为 O(\log n) 从而平衡时间复杂度。
或者通过已知信息快速计算贡献,也可以减少不必要的递归计算。
注意在 ask 函数内不能 pushdown,因为 ask 只查询 a 数组,pushdown 只影响 b 数组。
如果里面加上 pushdown,会导致 pushup 复杂度炸裂。
代码
#include <bits/stdc++.h>
#define it __int128
using namespace std;
const int N = 5e5 + 5;
int n, m, a[N];
struct Tree {
int l, r, len, Max;
it suma, sumb, flag1; int flag2; // flag1 是区间加常数;flag2 是右儿子标记,均针对 b 数组
} tr[N << 2];
inline void set_flag1(int u, it d) { tr[u].flag1 += d, tr[u].sumb += tr[u].len * d; }
it update(int u, int v, int d) {
// u 节点区间之前最大值为 v,将 u 区间内进行 b 序列操作,执行 d 次
// 返回区间 suma
if (v >= tr[u].Max) return set_flag1(u, (it)v * d), (it)v * tr[u].len; // 覆盖整个区间
if (tr[u].l == tr[u].r) return set_flag1(u, (it)max(tr[u].Max, v) * d), max(tr[u].Max, v); //叶子节点
it ret;
if (v >= tr[u << 1].Max) ret = update(u << 1, v, d) + update(u << 1 | 1, v, d); //左区间加常数懒标记 O(1)
else tr[u].flag2 += d, ret = update(u << 1, v, d) + tr[u].suma - tr[u << 1].suma;
tr[u].sumb += ret * d;
return ret;
}
inline void pushdown(int u) {
if (tr[u].flag1) set_flag1(u << 1, tr[u].flag1), set_flag1(u << 1 | 1, tr[u].flag1), tr[u].flag1 = 0;
if (tr[u].flag2) { it p = update(u << 1 | 1, tr[u << 1].Max, tr[u].flag2); tr[u].flag2 = 0; }
}
it ask(int u, int v) { //求 u 的前缀最大值之和,u 区间之前 Max = v
if (tr[u].l == tr[u].r) return max(v, tr[u].Max);
// pushdown(u); //!pushdown 只影响 b 数组,这里求 a 数组不能 pushdown,会影响复杂度。
if (v >= tr[u << 1].Max) return (it)v * tr[u << 1].len + ask(u << 1 | 1, v);
else return ask(u << 1, v) + (tr[u].suma - tr[u << 1].suma);
}
inline void pushupa(int u) {
tr[u].Max = max(tr[u << 1].Max, tr[u << 1 | 1].Max);
tr[u].suma = tr[u << 1].suma + ask(u << 1 | 1, tr[u << 1].Max);
}
inline void pushupb(int u) { tr[u].sumb = tr[u << 1].sumb + tr[u << 1 | 1].sumb; }
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r, tr[u].len = r - l + 1;
if (l == r) return tr[u].Max = tr[u].suma = a[l], void();
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushupa(u);
}
void change_a(int u, int x, int d) {
if (tr[u].l == tr[u].r) return tr[u].Max = d, void();
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) change_a(u << 1, x, d);
else change_a(u << 1 | 1, x, d);
pushupa(u);
}
int change_b(int u, int l, int r, int v) { // 返回最大值
if (tr[u].l >= l && tr[u].r <= r) return update(u, v, 1), max(v, tr[u].Max);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) v = max( v, change_b(u << 1, l, r, v) );
if (r > mid) v = max( v, change_b(u << 1 | 1, l, r, v) );
pushupb(u);
return v;
}
it query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sumb;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1; it res = 0;
if (l <= mid) res = query(u << 1, l, r);
if (r > mid) res += query(u << 1 | 1, l, r);
return res;
}
void print(it res) {
if (res > 9) print(res / 10);
putchar( (res % 10) ^ 48 );
if (res <= 9) return;
}
int opt, A, B;
signed main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
build(1, 1, n);
while (m--) {
scanf("%d%d%d", &opt, &A, &B);
if (opt == 1) change_a(1, A, B);
else if (opt == 2) it p = change_b(1, A, B, 0);
else print( query(1, A, B) ), putchar('\n');
}
return 0;
}