树链剖分问题
解题思路
本题有四个操作,一个是令树上某一段路径中的每个数加上一个数,一个是令树上某一棵子树中的每个数加上一个数,一个是查询树上某一段路径中的所有数的和,一个是查询树上某一棵子树中的所有数的和。
可以发现本题都是对树上路径进行操作,因此我们可以用树链剖分来做,对于树上某一段路径,我们就能用树链剖分拆分成 $O(logn)$ 段连续区间,因此对于树上某一段路径的操作就能转化成对某 $O(logn)$ 段连续区间进行操作。
而另外两个操作是对树上某一棵子树进行修改、查询,而树链剖分是将树上某一段路径转化成由 $dfs$ 序得到的序列中的一些连续区间。而对于 $dfs$ 序又有一个特性,对于任意一个子树中的节点编号是连续的,因此对于任意一棵子树能转化成序列中的某一段区间,同样变成了一个区间操作。
综上所述我们就将四个树上操作都转化成了区间操作,而要实现的区间操作就是将某一段区间每个数加上一个数,以及查询某一段区间中所有数的和,这两个操作用树状数组、线段树等等都能实现。这里比较常用的线段树来做,而在线段树中就是实现区间修改和区间查询两个操作,需要搭配懒标记来实现。
C++ 代码
#include <iostream>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 100010, M = N * 2;
struct Node
{
int l, r;
//sum 表示 [l, r] 的权值和,add 表示当前区间的子区间中每个点需要加上的数(懒标记)
LL sum, add;
}tr[N * 4];
int n, m;
int h[N], w[N], e[M], ne[M], idx;
//id[i] 表示节点 i 的 dfs 序中的编号
//nw[i] 表示 dfs 序中第 i 个点的权值
int id[N], nw[N], timestamp;
int d[N]; //d[i] 表示节点 i 的深度
int cnt[N]; //cnt[i] 表示以节点 i 为根节点的子树中的节点个数
int top[N]; //top[i] 表示节点 i 所在重链的顶点
int fa[N]; //fa[i] 表示节点 i 的父节点
int son[N]; //son[i] 表示节点 i 的重儿子
void add(int a, int b) //添加边
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs1(int u, int father, int depth) //预处理 dep[], fa[], cnt[], son[]
{
d[u] = depth, fa[u] = father, cnt[u] = 1;
for(int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if(j == father) continue;
dfs1(j, u, depth + 1);
cnt[u] += cnt[j];
if(cnt[son[u]] < cnt[j]) son[u] = j;
}
}
//t 表示当前重链的顶点
void dfs2(int u, int t) //找出 dfs 序和所有重链
{
id[u] = ++timestamp, nw[timestamp] = w[u], top[u] = t;
if(!son[u]) return; //没有儿子直接返回
dfs2(son[u], t); //否则优先搜索重儿子
for(int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if(j == fa[u] || j == son[u]) continue;
dfs2(j, j); //轻儿子所在重链的顶点就是他自己
}
}
void pushup(int u) //用子节点的信息更新当前节点的信息
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u) //将当前节点的懒标记下传
{
Node &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if(root.add)
{
left.add += root.add, left.sum += root.add * (left.r - left.l + 1);
right.add += root.add, right.sum += root.add * (right.r - right.l + 1);
root.add = 0;
}
}
void build(int u, int l, int r) //初始化线段树
{
tr[u] = {l, r, nw[r], 0};
if(l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int k) //区间修改
{
if(tr[u].l >= l && tr[u].r <= r)
{
tr[u].add += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(u << 1, l, r, k);
if(r > mid) modify(u << 1 | 1, l, r, k);
pushup(u);
}
LL query(int u, int l, int r) //区间查询
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL res = 0;
if(l <= mid) res += query(u << 1, l, r);
if(r > mid) res += query(u << 1 | 1, l, r);
return res;
}
void update_path(int u, int v, int k) //树上路径修改
{
while(top[u] != top[v]) //只要 u 和 v 不再同一个重链中,就能继续拆出重链
{
if(d[top[u]] < d[top[v]]) swap(u, v); //保证 top[u] 比 top[v] 更靠下
modify(1, id[top[u]], id[u], k); //此时就能将 u 所在的重链拆出来单独计算
u = fa[top[u]]; //将 u 跳到更上面的重链中
}
if(d[u] < d[v]) swap(u, v); //保证 u 比 v 更靠下
modify(1, id[v], id[u], k); //将拆出的最后一段重链单独计算
}
LL query_path(int u, int v) //树上路径查询
{
LL res = 0; //记录答案
while(top[u] != top[v]) //只要 u 和 v 不再同一个重链中,就能继续拆出重链
{
if(d[top[u]] < d[top[v]]) swap(u, v); //保证 top[u] 比 top[v] 更靠下
res += query(1, id[top[u]], id[u]); //此时就能将 u 所在的重链拆出来单独计算
u = fa[top[u]]; //将 u 跳到更上面的重链中
}
if(d[u] < d[v]) swap(u, v); //保证 u 比 v 更靠下
res += query(1, id[v], id[u]); //将拆出的最后一段重链单独计算
return res;
}
void update_tree(int u, int k) //树上子树修改
{
modify(1, id[u], id[u] + cnt[u] - 1, k); //转化成区间修改
}
LL query_tree(int u) //树上子树查询
{
return query(1, id[u], id[u] + cnt[u] - 1); //转化成区间查询
}
int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
memset(h, -1, sizeof h); //初始化邻接表
for(int i = 0; i < n - 1; i++)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a); //无向边
}
dfs1(1, -1, 1); //预处理 dep[], fa[], cnt[], son[]
dfs2(1, 1); //找出 dfs 序和所有重链
build(1, 1, n); //初始化线段树
scanf("%d", &m);
while(m--)
{
int op, u, v, k;
scanf("%d", &op);
if(op == 1) //树上路径修改
{
scanf("%d%d%d", &u, &v, &k);
update_path(u, v, k);
}
else if(op == 2) //树上子树修改
{
scanf("%d%d", &u, &k);
update_tree(u, k);
}
else if(op == 3) //树上路径查询
{
scanf("%d%d", &u, &v);
printf("%lld\n", query_path(u, v));
}
else //树上子树查询
{
scanf("%d", &u);
printf("%lld\n", query_tree(u));
}
}
return 0;
}