分析
思路:
1、存储所有的查询,以每个起始点为集合来存储对应的目标点以及对应的序号。
2、维护一个dist数组来存储每个节点到节点1(默认设置其为根节点)的距离(使用dfs来进行求取)。
3、维护一个并查集,在回溯的过程中去维护并查集(目的是能够去构建当前的指定子树的祖先结点),同时在回溯的过程中去遍历以当前u结点初始节点出发的集合,遍历每一个目标节点,若是该目标节点已经在此之前回溯完成了,那么此时可以进行两个结点最小路径的计算(起始结点距离根节点距离 + 目标节点距离根节点距离 - 2 * 初始目标节点的祖先结点距离根节点距离 = 两个结点最短路径
)。
看下这个过程:起始点为蓝色,目标点为红色
当我们去回溯到红色点时,由于之前查询节点添加了两个方向所以可以查到红色的目标值为蓝色点,而此时蓝色点已经回溯过了为2,此时就可以去找到蓝色的祖先结点,接着求得两个点的最小路径了。
实际上对于祖先结点的定位查找我们是根据在回溯的过程中去维护一个并查集,这样我们就能够在回溯时找到祖先节点从而求得最小距离:起始结点距离根节点距离 + 目标节点距离根节点距离 - 2 * 初始目标节点的祖先结点距离根节点距离 = 两个结点最短路径
题解思路:tarjan 离线LCA
复杂度分析:时间复杂度O(m+n);空间复杂度O(n)
import java.util.*;
import java.io.*;
class Main {
static final BufferedReader cin = new BufferedReader(new InputStreamReader(System.in));
static final int N = 200010, M = N * 2;
//读入结点
static int n, m;
//单链表式的邻接表
static int[] e = new int[M], ne = new int[M], h = new int[N], w = new int[M];
static int idx;
//维护查询集合
static ArrayList<int[]>[] queries = new ArrayList[N];
//dist:记录每个节点距离根节点1的长度
//p:维护并查集
//res:维护查询结果
static int[] dist = new int[N], p = new int[N], res = new int[N];
//记录在tarjan中是否进行访问过节点 1表示已访问,2表示已回溯完成
static int[] st = new int[N];
//添加节点到邻接表中
static void add(int a, int b, int c) {
e[idx] = b;
w[idx] = c;
ne[idx] = h[a];
h[a] = idx++;
}
//并查集查找祖先结点
static int find(int u) {
if (p[u] != u) p[u] = find(p[u]);
return p[u];
}
//dfs
static void dfs(int u, int fa) {
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j != fa) {
//记录距离,注意w[i]这里是在单链表中存储的值
dist[j] = dist[u] + w[i];
dfs(j, u);
}
}
}
//tarjan离线计算
static void tarjan(int u) {
st[u] = 1;//1表示已经访问过了
//遍历单链表
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
//表示还没有访问过
if (st[j] == 0) {
tarjan(j);
//维护并查集节点
p[j] = u;
}
}
//访问所有的查询节点
ArrayList<int[]> queriesList = queries[u];
if (queriesList != null) {
for (int[] query: queriesList) {
//y表示目标点,id表示查询的序号用于存储结果值
int y = query[0], id = query[1];
if (st[y] == 2) {
//通过并查集找到当前y的祖先结点
int anc = find(y);
res[id] = dist[y] + dist[u] - 2 * dist[anc];
}
}
}
//整体回溯完成设置st状态
st[u] = 2;
}
public static void main(String[] args) throws Exception{
String[] ss = cin.readLine().split(" ");
n = Integer.parseInt(ss[0]);
m = Integer.parseInt(ss[1]);
//初始化所有单链表的头结点
Arrays.fill(h,-1);
//读入初始节点与目标节点
for (int i = 1; i < n; i ++ ) {
ss = cin.readLine().split(" ");
int a = Integer.parseInt(ss[0]);
int b = Integer.parseInt(ss[1]);
int c = Integer.parseInt(ss[2]);
add(a, b, c);
add(b, a, c);
}
//读入查询节点的最小路径
for (int i = 0; i < m; i ++) {
ss = cin.readLine().split(" ");
int a = Integer.parseInt(ss[0]);
int b = Integer.parseInt(ss[1]);
if (a == b) continue;
if (queries[a] == null) queries[a] = new ArrayList<>();
if (queries[b] == null) queries[b] = new ArrayList<>();
//添加对应查询集合的目标节点与对应的查询次序
queries[a].add(new int[]{b, i});
queries[b].add(new int[]{a, i});
}
//进行dfs来取得每一个节点距离根节点1的长度
dfs(1, -1);
//初始化并查集
for (int i = 1; i <= n; i ++) p[i] = i;
//进行离线计算
tarjan(1);
//打印所有结果
for (int i = 0; i < m; i ++) {
System.out.println(res[i]);
}
}
}