方法1:树状数组
分析
-
对于最原始的树状数组存在两个操作:单点加,求区间和(即a[x]+=c, query[L~R]);
-
AcWing 242. 一个简单的整数问题的操作正好反过来:区间加,求单点和(即a[L~R]+=c, query[x]);
-
本题的操作更进一步:区间加,求区间和(即a[L~R]+=c, query[L~R]);
-
对于区间加,我们仍然可以使用差分的思想,将原数组a转化成差分数组b,则:
(1)$a[L,R]+=c \iff b[L]+=c, b[R+1]-=c$;
(2)$a[x] = \sum b[i], 1 \le i \le x$;
- 如何将原数组a转换为差分数组呢?转化过程如下,这里必须要求数据从a[1]开始,a[0]=0:
$$ b[1] = a[1] - a[0] \\ b[2] = a[2] - a[1] \\ … \\ b[n] = a[n] - a[n - 1] $$
- 这样,对于数组a的区间加法可以转化成对数组b的单点加;但是对数组a求区间和我们就要考虑一下如何求解了。
- 求数组a的区间和,只需要求出数组a的前缀和即可,即求出:
$$ \sum_{i=1}^{x} a_i $$
又因为:$a[i] = \sum b[j], 1 \le j \le i$,所以有:
$$
\sum_{i=1}^{x} a_i = \sum_{i=1}^{x} \sum_{j=1}^{i} b_i = (b_1)+(b_1+b_2)+…+(b_1+b_2+…+b_x)
$$
如下图(蓝色的是我们需要求解的部分,红色的是我们补上的内容,则蓝色和=全部和-红色和):
则有:
$$
\sum_{i=1}^{x} a_i = \Bigl(\sum_{i=1}^{x} b_i \Bigr) \times (x+1) - (b1 + 2 \times b_2 + … + x \times b_x)
$$
因此,我们在操作的同时维护两个前缀和即可,分别是:$\sum b_i$和$\sum i \times b_i$。
- 另外数组a中的数据最大为$10^9$,操作次数为$10^5$,每次最大加上1000,因为是区间和,所以可能会超过int的范围,因此需要使用long long存储结果。
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m; // 数列长度、操作个数
int a[N]; // 原数组
LL tr1[N]; // a对应的差分数组b的前缀和
LL tr2[N]; // 维护b[i] * i的前缀和
int lowbit(int x) {
return x & -x;
}
void add(LL tr[], int x, LL c) {
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
LL sum(LL tr[], int x) {
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
LL prefix_sum(int x) {
return sum(tr1, x) * (x + 1) - sum(tr2, x);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
// 使用差分数组b对树状数组初始化
for (int i = 1; i <= n; i++) {
int b = a[i] - a[i - 1];
add(tr1, i, b);
add(tr2, i, (LL)b * i);
}
while (m--) {
char op[2];
int l, r, d;
scanf("%s%d%d", op, &l, &r);
if (*op == 'C') {
scanf("%d", &d);
// b[l] += d, tr2维护的是b[i] * i的前缀和
// 因此b[l]增加d, 则(b[l] + d) * l增加了l*d
add(tr1, l, d), add(tr2, l, l * d);
// b[r + 1] -= d
add(tr1, r + 1, -d), add(tr2, r + 1, (r + 1) * -d);
} else {
printf("%lld\n", prefix_sum(r) - prefix_sum(l - 1));
}
}
return 0;
}
方法2:线段树
分析
-
本题对应的是区间加,区间查询问题,可以转化为单点加,区间查询的问题,具体可以参考:树状数组。这里使用线段树解决这个问题。
-
本题需要用到线段树五个操作中最复杂的一个,即
pushdown
:把当前父节点的修改信息下传到子节点,也被称为懒标记(延迟标记)。 -
对于区间修改,最坏的情况下,时间复杂度是$O(n)$的,比如将整个区间修改,这是我们不能接受的,因此pushdown操作应运而生。其核心思想是懒标记,即当树中某个区间已经完全被我们修改的区间包含了,就不再递归下去,直接返回,同时在该节点标记上需要加上一个数。对于本题来说,下面是懒标记的具体用法。
struct Node {
int l, r; // 区间左右端点
int sum; // 如果考虑当前节点及子节点上的所有标记,其区间[l, r]的总和就是sum
int add; // 懒标记,表示需要给以当前节点为根的子树中的每一个节点都加上add这个数(不包含当前节点)
}
-
通过这样的操作,修改的时间复杂度也变成了$O(log(n))$了。
-
这样做之后,我们的查询操作(
query
)也要跟着变化,如下图:
这个操作对应到代码上是(当前节点是root,左孩子是left,右孩子是right):
void pushdown(int u) {
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add) {
left.add += root.add, left.sum += (LL)(left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (LL)(right.r - right.l + 1) * root.add;
root.add = 0;
}
}
- 修改(
modify
)操作,如果当前考察的整个区间都要加上一个数,则可以直接加上,就不需要进行pushdown
操作了;否则也要进行类似于上面的pushdown
操作。
void modify(int u, int l, int r, int d) {
if (tr[u].l >= l && tr[u].r <= r) { // 当前节点对应的区间完全在[l, r]之间
tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
tr[u].add += d;
} else { // 一定要分裂
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
}
代码
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m; // 数列长度、操作个数
int a[N]; // 输入的数组
struct Node {
int l, r;
LL sum; // 如果考虑当前节点及子节点上的所有标记,其区间[l, r]的总和就是sum
LL add; // 懒标记,表示需要给以当前节点为根的子树中的每一个节点都加上add这个数(不包含当前节点)
} tr[N * 4];
// 由子节点的信息,来计算父节点的信息
void pushup(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
// 把当前父节点的修改信息下传到子节点,也被称为懒标记(延迟标记)
void pushdown(int u) {
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add) {
left.add += root.add, left.sum += (LL)(left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (LL)(right.r - right.l + 1) * root.add;
root.add = 0;
}
}
// 创建线段树
void build(int u, int l, int r) {
if (l == r) tr[u] = {l, r, a[l], 0};
else {
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
// 将a[l~r]都加上d
void modify(int u, int l, int r, LL d) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
tr[u].add += d;
} else { // 一定要分裂
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
}
// 返回a[l~r]元素之和
LL query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if (l <= mid) sum += query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
build(1, 1, n);
char op[2];
int l, r, d;
while (m--) {
scanf("%s%d%d", op, &l, &r);
if (*op == 'C') {
scanf("%d", &d);
modify(1, l, r, d);
} else {
printf("%lld\n", query(1, l, r));
}
}
return 0;
}