分析
- 考虑线段树中的每个节点存储什么信息,如下:
struct Node {
int l, r;
int sum; // 区间总和(已经对p取模)
int add, mul; // 懒标记,表示当前节点的子节点的sum值都需要进行sum * mul + add的运算
}
- 无论是加法还是乘法,我们统一转化为$t \times a + b$的形式
(1)如果题目是乘以一个数k,则a = k, b = 0;
(2)如果是加上一个数k,则a = 1, b = k;
- 对于$t \times a + b$进行同样的操作,即$(t \times a + b) \times c + d$,则可以变为$t \times a \times c + b \times c + d$,所$mul=a \times c, add=b \times c + d$。
#include <iostream>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, p, m; // 元素个数,取模的数,操作个数
int a[N];
struct Node {
int l, r;
int sum, add, mul;
} tr[N * 4];
void pushup(int u) {
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
void build(int u, int l, int r) {
if (l == r) tr[u] = {l, r, a[l] % p, 0, 0};
else {
tr[u] = {l, r, 0, 0, 1}; // sum值会在pushup时被更新
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
// 对节点t进行 (t * mul + add) 的操作
void eval(Node &t, int add, int mul) {
// sum = sum * mul + len * add;(len为区间长度)
t.sum = ((LL)t.sum * mul + (LL)(t.r - t.l + 1) * add) % p;
t.mul = (LL)t.mul * mul % p;
t.add = ((LL)t.add * mul + add) % p;
}
void pushdown(int u) {
eval(tr[u << 1], tr[u].add, tr[u].mul);
eval(tr[u << 1 | 1], tr[u].add, tr[u].mul);
tr[u].add = 0, tr[u].mul = 1;
}
void modify(int u, int l, int r, int add, int mul) {
if (tr[u].l >= l && tr[u].r <= r) eval(tr[u], add, mul);
else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, add, mul);
if (r > mid) modify(u << 1 | 1, l, r, add ,mul);
pushup(u);
}
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum = (sum + query(u << 1 | 1, l, r)) % p;
return sum;
}
int main() {
scanf("%d%d", &n, &p);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
build(1, 1, n);
scanf("%d", &m);
while (m--) {
int t, l, r, d;
scanf("%d%d%d", &t, &l, &r);
if (t == 1) {
scanf("%d", &d);
modify(1, l, r, 0, d); // *d + 0
} else if (t == 2) {
scanf("%d", &d);
modify(1, l, r, d, 1); // *1 + d
} else {
printf("%d\n", query(1, l, r));
}
}
return 0;
}