给定一棵有根树,若节点 $z$ 既是节点 $x$ 的祖先,也是节点 $y$ 的祖先,则称 $z$ 是 $x,y$
的公共祖先。在 $x,y$ 的所有公共祖先中,深度最大的一个称为 $x,y$ 的最近公共祖先,记为 $LCA(x,y)$。
$LCA(x,y)$ 是 $x$ 到根的路径与 $y$ 到根的路径的交会点。它也是 $x$ 与 $y$ 之间的路径上深度最小的节点。求最近公共祖先的方法通常有三种:
向上标记法
从 $x$ 向上走到根节点,并标记所有经过的节点。
从 $y$ 向上走到根节点,当第一次遇到已标记的节点时,就找到了 $LCA(x,y)$。
对于每个询问,向上标记法的时间复杂度最坏为 $O(n)$。
树上倍增法
树上倍增法是一个很重要的算法。除了求 $LCA$ 之外,它在很多问题中都有广泛应用。
设 $F[x,k]$ 表示 $x$ 的 $2^k$ 辈祖先 ,即从 $x$ 向根节点走 $2^k$ 步到达的节点。特别地,若该节点不存在,则令 $F[x,k] = 0$。$F[x,0]$ 就是 $x$ 的父节点。除此之外,
任意的 $k∈[1, logn], F[x,k] = F[F[x,k - 1],k - 1]$ 。
这类似于一个动态规划的过程,“阶段” 就是节点的深度。因此,我们可以对树进行广度优先遍历,按照层次顺序,在节点入队之前,计算它在 $F$ 数组中相应的值。以上部分是预处理,时间复杂度为 $O(n logn)$,之后可以多次对不同的 $x,y$ 计算 $LCA$,每次询问的时间复杂度为 $O(logn)$。
基于 $F$ 数组计算 $LCA(x,y)$ 分为以下几步:
- 设 $d[x]$ 表示 $x$ 的深度。不妨设 $d[x]≥ d[y]$(否则可交换 $x,y$ )。
- 用二进制拆分思想,把 $x$ 向上调整到与 $y$ 同一深度。具体来说,就是依次尝试从 $x$ 向上走 $k = 2^{logn} , … , 2^1 , 2^0$ 步,检查到达的节点是否比 $y$ 深。在每次检查中,若是,则令 $x = F[x,k]$。
- 若此时 $x = y$,说明已经找到了 $LCA$,$LCA$ 就等于 $y$。
这就是上面的图中的第三种情况。
4:用二进制拆分思想,把 $x,y$ 同时向上调整,并保持深度一致且二者不相会。具体来说,就是依次尝试把 $x,y$ 同时向上走 $k = 2^{logn} , … , 2^1 , 2^0$ 步,在每次尝试中,若 $F[x,k]≠F[y,k]$(即仍未相会),则令 $x = F[x,k],y = F[y,k]$。
5:此时 $x,y$ 必定只差一步就相会了,它们的父节点 $F[x,0]$ 就是 $LCA$。
多次查询树上两点之间的距离,时间复杂度为 $O((n + m)logn)$。
// 多次查询树上两点之间的距离板子
// 树上倍增
const int N = 50010 , M = 2 * N;
int f[N][20] , d[N] , dist[N]; // d[]是 depth[], dist[] 存储该点到根节点的最短距离
int T ,t , n , m;
int h[N] , e[M] , ne[M] , w[M] , idx;
int q[N];
void add(int a, int b, int c){
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
// 预处理 f[x, k]不存在时=0; f[x, k] = f[f[x, k - 1] ,k - 1]
void bfs(int){
memset(dist , 0x3f , sizeof dist);
d[0] = 0; // 哨兵
d[root] = 1;
int hh = 0 , tt = 0;
q[0] = root;
while(hh <= tt)
{
int x = q[hh++];
for(int i = h[x] ; ~i ; i = ne[i])
{
int y = e[i];
if(d[y]) continue;
d[y] = d[x] + 1; // 深度 + 1
dist[y] = dist[x] + w[i];// 新距离
f[y][0] = x;
for(int j = 1 ; j <= t ; j++) // t = (int)(log(n) / log(2)) + 1;
f[y][j] = f[f[y][j-1]][j-1];
q[++tt] = y;
}
}
}
// 查询 x , y 的公共祖先
int lca(int x , int y){ // d[x] <= d[y]
if(d[x] > d[y]) swap(x, y);
// 将 y 向上调整 至与 x 同一深度
for(int k = t ; k >= 0 ; k--)
if(d[f[y][k]] >= d[x]) y = f[y][k];
if(x == y) return x;
// 将 x , y 同时向上调整,并且保证深度一致且不会相会
for(int k = t ; k >= 0 ; k--)
if(f[x][k] != f[y][k]){
x = f[x][k];
y = f[y][k];
}
return f[x][0];
}
int main()
{
cin >> T;
while(T--)
{
cin >> n >> m;
t = (int)(log(n) / log(2)) + 1;
memset(h, -1, sizeof h);
idx = 0;
memset(d , 0 , sizeof d);
// 读入一颗树树
for(int i = 1 ; i <= n ; i++){
int x, y, z;
scanf("%d%d%d" ,&x, &y ,&z);
add(x, y, z) , add(y, x, z);
}
bfs();
// 回答询问
for(int i = 1 ; <= m ; i++){
int x, y;
scanf("%d%d" , &x ,&y);
// x, y 两点之间的距离
printf("%d\n" , dist[x]+dist[y] - 2 * dist[lca(x, y)]);
}
}
return 0;
}
LCA 的Tarjan 算法
Tarjan 算法本质上是使用并查集对 “ 向上标记法” 的优化。它是一个离线算法,需要把 $m$ 个询问一次性读入,统一计算,最后统一输出。时间复杂度为 $O(n+m)$。在深度优先遍历的任意时刻,树中节点分为三类:
- 已经访问完毕并且回溯的节点。在这些节点上标记一个整数 $2$。
- 已经开始递归,但尚未回溯的节点。这些节点就是当前正在访问的节点 $x$ 以及 $x$ 的祖先。在这些节点上标记一个整数 $1$。
- 尚未访问的节点。这些节点没有标记。
对于正在访问的节点 $x$,它到根节点的路径已经标记为 $1$。若 $y$ 是已经访问完毕并且回溯的节点,则 $LCA(x,y)$ 就是从 $y$ 向上走到根,第一个遇到的标记为 $1$ 的节点。
可以利用并查集进行优化,当一个节点获得整数 $2$ 的标记时,把它所在的集合合并到它的父节点所在的集合中(合并时它的父节点标记一定为 $1$,且单独构成一个集合)。
这相当于每个完成回溯的节点都有一个指针指向它的父节点,只需查询所在集合的代表元素(并查集的 $find$ 操作),就等价于从 $y$ 向上一直走到一个开始递归但尚未回溯的节点(具有标记 $1$),即$LCA(x,y)$。
在 $x$ 回溯之前,标记情况与合并情况如下图所示。黑色表示标记为 $1$,灰色表示标记为 $2$,白色表示没有标记,箭头表示执行了合并操作。
此时扫描与 $x$ 相关的所有询问,若询问当中的另一个点 $y$ 的标记为 $2$,就知道了该询问的回答应该是 $y$ 在并查集中的代表元素(并查集中 $find(y)$ 函数的结果)。
多次查询树上两点之间的距离,时间复杂度为 $O(n+m)$。
合并操作是在回溯完当前节点之后进行合并
const int N = 50010 , M = N * 2;
int T, n , m , t;
int h[N] , e[M] , ne[M] ,w[M] ,idx;
int p[N] , dist[N] , v[N] , lca[N] , ans[N];
vector<int> query[N]; // 存 询问的另一节点
vector<int> query_id[N]; // 存 询问的编号
void add(int a, int b, int c){
e[idx] = b, ne[idx] = h[a] , w[idx] = c , h[a] = idx++
}
void add_query(int x, int y, int id){
query[x].emplace_back(y) , query_id[x].emplace_back(id);
query[y].emplace_back(x) , query_id[y].emplace_back(id);
}
int find(int x){
if(x != p[x]) p[x] = find(p[x]);
return p[x];
}
// 处理 x 节点到其他节点的距离
void tarjan(int x){
// 处理在 遍历的分支
v[x] = 1;
for(int i = h[x] ; ~i ; i = ne[i]){
int y = e[i];
if(v[y]) continue;
dist[y] = dist[x] + w[i];
tarjan(y);
p[y] = x;
}
// 计算距离
for(int i = 0 , i < query[x].size ; i++)
{
int y = query[x][i] , id = query_id[x][i];
if(v[y] == 2){
int lca = find(d);
ans[id] = min(ans[id] , d[x] + d[y] - 2 * d[lca]);
}
}
v[x] = 2;
}
int main()
{
cin >> T;
while(T--)
{
cin >> n >> m;
memset(h, -1, sizeof h);
idx = 0;
for(int i = 1 ; i <= n ; i++){
int x, y, z;
scanf("%d%d%d" ,&x, &y, &z);
add(x, y, z) , add(y, x, z);
}
// 询问
for(int i = 1 ; i <= m ; i++){
int x , y;
scanf("%d%d" , &n, &m);
if(x == y) ans[i] = 0;
else {
add_query(x, y,i);
ans[i] = 1 << 30;
}
}
tarjan(1);
for(int i = 1 ; i <= m ; i++) printf("%d\n" , ans[i]);
}
return 0;
}