给定一棵有根树,若节点 z 既是节点 x 的祖先,也是节点 y 的祖先,则称 z 是 x,y
的公共祖先。在 x,y 的所有公共祖先中,深度最大的一个称为 x,y 的最近公共祖先,记为 LCA(x,y)。
LCA(x,y) 是 x 到根的路径与 y 到根的路径的交会点。它也是 x 与 y 之间的路径上深度最小的节点。求最近公共祖先的方法通常有三种:
向上标记法
从 x 向上走到根节点,并标记所有经过的节点。
从 y 向上走到根节点,当第一次遇到已标记的节点时,就找到了 LCA(x,y)。
对于每个询问,向上标记法的时间复杂度最坏为 O(n)。
树上倍增法
树上倍增法是一个很重要的算法。除了求 LCA 之外,它在很多问题中都有广泛应用。
设 F[x,k] 表示 x 的 2k 辈祖先 ,即从 x 向根节点走 2k 步到达的节点。特别地,若该节点不存在,则令 F[x,k]=0。F[x,0] 就是 x 的父节点。除此之外,
任意的 k∈[1,logn],F[x,k]=F[F[x,k−1],k−1] 。
这类似于一个动态规划的过程,“阶段” 就是节点的深度。因此,我们可以对树进行广度优先遍历,按照层次顺序,在节点入队之前,计算它在 F 数组中相应的值。以上部分是预处理,时间复杂度为 O(nlogn),之后可以多次对不同的 x,y 计算 LCA,每次询问的时间复杂度为 O(logn)。
基于 F 数组计算 LCA(x,y) 分为以下几步:
- 设 d[x] 表示 x 的深度。不妨设 d[x]≥d[y](否则可交换 x,y )。
- 用二进制拆分思想,把 x 向上调整到与 y 同一深度。具体来说,就是依次尝试从 x 向上走 k=2logn,…,21,20 步,检查到达的节点是否比 y 深。在每次检查中,若是,则令 x=F[x,k]。
- 若此时 x=y,说明已经找到了 LCA,LCA 就等于 y。
这就是上面的图中的第三种情况。
4:用二进制拆分思想,把 x,y 同时向上调整,并保持深度一致且二者不相会。具体来说,就是依次尝试把 x,y 同时向上走 k=2logn,…,21,20 步,在每次尝试中,若 F[x,k]≠F[y,k](即仍未相会),则令 x=F[x,k],y=F[y,k]。
5:此时 x,y 必定只差一步就相会了,它们的父节点 F[x,0] 就是 LCA。
多次查询树上两点之间的距离,时间复杂度为 O((n+m)logn)。
// 多次查询树上两点之间的距离板子
// 树上倍增
const int N = 50010 , M = 2 * N;
int f[N][20] , d[N] , dist[N]; // d[]是 depth[], dist[] 存储该点到根节点的最短距离
int T ,t , n , m;
int h[N] , e[M] , ne[M] , w[M] , idx;
int q[N];
void add(int a, int b, int c){
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
// 预处理 f[x, k]不存在时=0; f[x, k] = f[f[x, k - 1] ,k - 1]
void bfs(int){
memset(dist , 0x3f , sizeof dist);
d[0] = 0; // 哨兵
d[root] = 1;
int hh = 0 , tt = 0;
q[0] = root;
while(hh <= tt)
{
int x = q[hh++];
for(int i = h[x] ; ~i ; i = ne[i])
{
int y = e[i];
if(d[y]) continue;
d[y] = d[x] + 1; // 深度 + 1
dist[y] = dist[x] + w[i];// 新距离
f[y][0] = x;
for(int j = 1 ; j <= t ; j++) // t = (int)(log(n) / log(2)) + 1;
f[y][j] = f[f[y][j-1]][j-1];
q[++tt] = y;
}
}
}
// 查询 x , y 的公共祖先
int lca(int x , int y){ // d[x] <= d[y]
if(d[x] > d[y]) swap(x, y);
// 将 y 向上调整 至与 x 同一深度
for(int k = t ; k >= 0 ; k--)
if(d[f[y][k]] >= d[x]) y = f[y][k];
if(x == y) return x;
// 将 x , y 同时向上调整,并且保证深度一致且不会相会
for(int k = t ; k >= 0 ; k--)
if(f[x][k] != f[y][k]){
x = f[x][k];
y = f[y][k];
}
return f[x][0];
}
int main()
{
cin >> T;
while(T--)
{
cin >> n >> m;
t = (int)(log(n) / log(2)) + 1;
memset(h, -1, sizeof h);
idx = 0;
memset(d , 0 , sizeof d);
// 读入一颗树树
for(int i = 1 ; i <= n ; i++){
int x, y, z;
scanf("%d%d%d" ,&x, &y ,&z);
add(x, y, z) , add(y, x, z);
}
bfs();
// 回答询问
for(int i = 1 ; <= m ; i++){
int x, y;
scanf("%d%d" , &x ,&y);
// x, y 两点之间的距离
printf("%d\n" , dist[x]+dist[y] - 2 * dist[lca(x, y)]);
}
}
return 0;
}
LCA 的Tarjan 算法
Tarjan 算法本质上是使用并查集对 “ 向上标记法” 的优化。它是一个离线算法,需要把 m 个询问一次性读入,统一计算,最后统一输出。时间复杂度为 O(n+m)。在深度优先遍历的任意时刻,树中节点分为三类:
- 已经访问完毕并且回溯的节点。在这些节点上标记一个整数 2。
- 已经开始递归,但尚未回溯的节点。这些节点就是当前正在访问的节点 x 以及 x 的祖先。在这些节点上标记一个整数 1。
- 尚未访问的节点。这些节点没有标记。
对于正在访问的节点 x,它到根节点的路径已经标记为 1。若 y 是已经访问完毕并且回溯的节点,则 LCA(x,y) 就是从 y 向上走到根,第一个遇到的标记为 1 的节点。
可以利用并查集进行优化,当一个节点获得整数 2 的标记时,把它所在的集合合并到它的父节点所在的集合中(合并时它的父节点标记一定为 1,且单独构成一个集合)。
这相当于每个完成回溯的节点都有一个指针指向它的父节点,只需查询所在集合的代表元素(并查集的 find 操作),就等价于从 y 向上一直走到一个开始递归但尚未回溯的节点(具有标记 1),即LCA(x,y)。
在 x 回溯之前,标记情况与合并情况如下图所示。黑色表示标记为 1,灰色表示标记为 2,白色表示没有标记,箭头表示执行了合并操作。
此时扫描与 x 相关的所有询问,若询问当中的另一个点 y 的标记为 2,就知道了该询问的回答应该是 y 在并查集中的代表元素(并查集中 find(y) 函数的结果)。
多次查询树上两点之间的距离,时间复杂度为 O(n+m)。
合并操作是在回溯完当前节点之后进行合并
const int N = 50010 , M = N * 2;
int T, n , m , t;
int h[N] , e[M] , ne[M] ,w[M] ,idx;
int p[N] , dist[N] , v[N] , lca[N] , ans[N];
vector<int> query[N]; // 存 询问的另一节点
vector<int> query_id[N]; // 存 询问的编号
void add(int a, int b, int c){
e[idx] = b, ne[idx] = h[a] , w[idx] = c , h[a] = idx++
}
void add_query(int x, int y, int id){
query[x].emplace_back(y) , query_id[x].emplace_back(id);
query[y].emplace_back(x) , query_id[y].emplace_back(id);
}
int find(int x){
if(x != p[x]) p[x] = find(p[x]);
return p[x];
}
// 处理 x 节点到其他节点的距离
void tarjan(int x){
// 处理在 遍历的分支
v[x] = 1;
for(int i = h[x] ; ~i ; i = ne[i]){
int y = e[i];
if(v[y]) continue;
dist[y] = dist[x] + w[i];
tarjan(y);
p[y] = x;
}
// 计算距离
for(int i = 0 , i < query[x].size ; i++)
{
int y = query[x][i] , id = query_id[x][i];
if(v[y] == 2){
int lca = find(d);
ans[id] = min(ans[id] , d[x] + d[y] - 2 * d[lca]);
}
}
v[x] = 2;
}
int main()
{
cin >> T;
while(T--)
{
cin >> n >> m;
memset(h, -1, sizeof h);
idx = 0;
for(int i = 1 ; i <= n ; i++){
int x, y, z;
scanf("%d%d%d" ,&x, &y, &z);
add(x, y, z) , add(y, x, z);
}
// 询问
for(int i = 1 ; i <= m ; i++){
int x , y;
scanf("%d%d" , &n, &m);
if(x == y) ans[i] = 0;
else {
add_query(x, y,i);
ans[i] = 1 << 30;
}
}
tarjan(1);
for(int i = 1 ; i <= m ; i++) printf("%d\n" , ans[i]);
}
return 0;
}