代码主体没有变化,但是稍微简化了一下。不需要用0,1,2三个状态来表示,之需要0,1即可。
原因是如果另一个点在当前搜索路径上,那么这个点还没有被并查集合并,因此find找到的也是自己,所以不影响计算结果。
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> PII;
const int N = 10010;
const int M = 2*N;
int h[N], e[M], ne[M], w[M], idx;
int dist[N];
vector<PII> q[N];
int res[M];
int p[N];
int n, m;
int st[N];
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void add(int a, int b, int c){
e[idx] = b;
ne[idx] = h[a];
w[idx] = c;
h[a] = idx ++;
}
void dfs(int u, int prev){
for(int i = h[u]; i != -1; i = ne[i]){
int j = e[i];
if(j != prev){
dist[j] = dist[u] + w[i];
dfs(j, u);
}
}
}
void tarjan(int u){
st[u] = 1;
for(int i = h[u]; i != -1; i = ne[i]){
int j = e[i];
if(!st[j]){
tarjan(j);
p[j] = u;
}
}
for(auto query: q[u]){
int id = query.second;
int another = query.first;
if(st[another]){
int anc = find(another);
res[id] = dist[u] + dist[another] - 2*dist[anc];
}
}
}
int main(){
memset(h, -1, sizeof h);
memset(dist, 0x3f3f3f, sizeof dist);
dist[1] = 0;
cin >> n >> m;
for(int i = 1; i < n; i ++){
int x, y, k;
cin >> x >> y >> k;
add(x, y, k);
add(y, x, k);
}
for(int i = 0; i <= n; i ++) p[i] = i;
dfs(1, -1);
for(int i = 1; i <= m; i ++){
int x, y;
cin >> x >> y;
q[x].push_back({y, i});
q[y].push_back({x, i});
}
tarjan(1);
for(int i = 1; i <= m; i ++){
cout << res[i] << endl;
}
return 0;
}