$\huge{树链剖分}$
$引入$
树链剖分用于解决在树上执行的操作
将树上操作变为区间操作,用区间来维护,通常用线段树
$操作$
1 x y z
,表示将树从 $x$ 到 $y$ 结点最短路径上所有节点的值都加上 $z$。2 x y
,表示求树从 $x$ 到 $y$ 结点最短路径上所有节点的值之和。3 x z
,表示将以 $x$ 为根节点的子树内所有节点值都加上 $z$。4 x
表示求以 $x$ 为根节点的子树内所有节点值之和。
$概念$
- 重儿子:对于每一个非叶子节点,它的儿子中子树节点数量最多的那一个儿子为该节点的重儿子。
- 轻儿子:对于每一个非叶子节点,它的儿子中非重儿子的剩下所有儿子即为轻儿子。
- 叶子节点没有重儿子也没有轻儿子。
- 重边:连接任意两个重儿子的边叫做重边。
- 轻边:剩下的即为轻边。
- 重链:相邻重边连起来的 连接一条重儿子 的链叫重链。
- 对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链。
- 每一条重链以轻儿子为起点。
- DFN序:从根节点开始递归,先弹入节点,再递归以重儿子为根的子树,最后递归其他子树
$处理$
在询问之前,首先要处理好节点的信息,分两个$dfs$完成。
$dfs1$
需要处理:
- 标记父节点信息
father
- 标记所在深度
depth
- 记录子树大小
sz
- 标记非叶子节点的重儿子
son
void dfs1(int u, int z, int depth)
{
sz[u] = 1, fa[u] = z, dep[u] = depth;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != z)
{
dfs1(j, u, depth + 1);
sz[u] += sz[j];//标记子树大小
if (sz[j] > sz[son[u]])son[u] = j;//选举重儿子
}
}
}
$dfs2$
需要处理:
- 标记每个点的新编号
dfn序
- 处理每个点所在链的顶端
top
void dfs2(int u, int topf)
{
id[u] = ++ cnt, wt[cnt] = w[u];//标记每个点的新编号
top[u] = topf;//处理每个点所在链的顶端
if (!son[u])return;//叶子节点,返回
dfs2(son[u], topf);//先递归重儿子
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u])continue;
dfs2(j, j);//每一个轻儿子都有一条从它开始的重链
}
}
$询问与维护$
$子树维护$
询问以$u$为根的子树上所有点的权值之和,
或修改$u$为根的子树上所有点的值。
在区间上长度为sz[u]
,区间为$[id[u],id[u]+sz[u]-1]$,用线段树维护。
void update_tree(int x, int z)//修改
{
update(1, id[x], id[x] + sz[x] - 1, z);
}
LL query_tree(int x)//询问
{
return query(1, id[x], id[x] + sz[x] - 1);
}
$路径维护$
树链剖分一般将树上任意两点的路径,划分成了不超过$logn$段的重链。
由于dfn序
中,重链上点的编号连续,所以每一条重链可以用线段树区间维护
每次将两点top
更低的点往上跳,类似lca
中的倍增向上法
void update_path(int x, int y, int z)//修改
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])swap(x, y);
update(1, id[top[x]], id[x], z);//区间维护
x = fa[top[x]];//向上跳
}
if (dep[x] < dep[y])swap(x, y);
update(1, id[y], id[x], z);//剩余部分维护
}
LL query_path(int x, int y)//询问
{
LL res = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])swap(x, y);
res += query(1, id[top[x]], id[x]);//区间维护
x = fa[top[x]];//向上跳
}
if (dep[x] < dep[y])swap(x, y);
res += query(1, id[y], id[x]);//剩余部分维护
return res;
}
$实现$
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 100010, M = N * 2;
int w[N], wt[N], id[N], cnt;
int sz[N], fa[N], dep[N], son[N], top[N];
int h[N], e[M], ne[M], idx;
struct Node {
int l, r;
LL sum, add;
}tr[N * 4];
void pushup(int u)//向上更新
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)//下传标记
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add += root.add, left.sum += (left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (right.r - right.l + 1) * root.add;
root.add = 0;
}
}
void add(int a, int b)//加边
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void build(int u, int l, int r)//线段树建树
{
tr[u] = {l, r, wt[l]};
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int k)//线段树修改
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += (tr[u].r - tr[u].l + 1) * k;
tr[u].add += k;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)update(u << 1, l, r, k);
if (r > mid)update(u << 1 | 1, l, r, k);
pushup(u);
}
LL query(int u, int l, int r)//线段树查询
{
if (tr[u].l >= l && tr[u].r <= r)return tr[u].sum;
pushdown(u);
LL res = 0;
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid)res += query(u << 1, l, r);
if (r > mid)res += query(u << 1 | 1, l, r);
return res;
}
void dfs1(int u, int z, int depth)
{
sz[u] = 1, fa[u] = z, dep[u] = depth;//子树大小,父节点,节点深度
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != z)
{
dfs1(j, u, depth + 1);
sz[u] += sz[j];//更新子树大小
if (sz[j] > sz[son[u]])son[u] = j;//更新重儿子
}
}
}
void dfs2(int u, int topf)
{
id[u] = ++ cnt, wt[cnt] = w[u];//给定新编号
top[u] = topf;//链顶节点
if (!son[u])return;
dfs2(son[u], topf);
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u])continue;
dfs2(j, j);
}
}
void update_path(int x, int y, int z)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])swap(x, y);
update(1, id[top[x]], id[x], z);
x = fa[top[x]];
}
if (dep[x] < dep[y])swap(x, y);
update(1, id[y], id[x], z);
}
LL query_path(int x, int y)
{
LL res = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])swap(x, y);
res += query(1, id[top[x]], id[x]);
x = fa[top[x]];
}
if (dep[x] < dep[y])swap(x, y);
res += query(1, id[y], id[x]);
return res;
}
void update_tree(int x, int z)
{
update(1, id[x], id[x] + sz[x] - 1, z);
}
LL query_tree(int x)
{
return query(1, id[x], id[x] + sz[x] - 1);
}
int main()
{
int n, m;
scanf("%d", &n);
for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
dfs1(1, -1, 1), dfs2(1, 1);//处理
build(1, 1, n);
scanf("%d", &m);
while (m -- )
{
int opt;
scanf("%d", &opt);
if (opt == 1)
{
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
update_path(x, y, z);
}
if (opt == 2)
{
int x, y;
scanf("%d%d", &x, &y);
update_tree(x, y);
}
if (opt == 3)
{
int x, y;
scanf("%d%d", &x, &y);
printf("%lld\n", query_path(x, y));
}
if (opt == 4)
{
int x;
scanf("%d", &x);
printf("%lld\n", query_tree(x));
}
}
return 0;
}