题目描述
tarjan算法把所有点分成三类:
1. 已经遍历过,且回溯过的点
2. 正在搜索的分支
3. 还未搜索的分支
#include <iostream>
#include <cstring>
#include <vector>
using namespace std;
typedef pair<int,int> pii;
const int N=20010;
int n,m;
int h[N],e[N],ne[N],w[N],idx,st[N],dist[N];
int res[N];
int p[N];
vector<pii> query[N];
void add(int a,int b,int c)
{
e[idx]=b;
ne[idx]=h[a];
w[idx]=c;
h[a]=idx++;
}
int dfs(int u,int fa)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==fa) continue;
dist[j]=dist[u]+w[i];
dfs(j,u);
}
}
int find(int x)
{
if(p[x]!=x) p[x]=find(p[x]);
else return x;
}
void tarjan(int u)
{
//开始遍历u这个点
st[u]=1;
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(!st[j])
{
//如果没遍历j,则遍历j
tarjan(j);
//遍历完了再合并
p[j]=u;
}
}
//在遍历点u后 扫描所有和u相关的询问 如果询问中另一个点b已经被遍历+回溯过了,则可以计算一个询问
for(auto item:query[u])
{
int i=item.second;
int b=item.first;
//如果查询的点,已遍历完
//只有回溯完的点,与当前点的公共祖先,是b的根节点
//回溯完到根节点后把根节点左边的子树上的点都用并查集合并到根节点
if(st[b]==2)
{
int anc=find(b);
res[i]=dist[u]+dist[b]-2*dist[anc];
}
}
st[u]=2;
}
int main()
{
cin>>n>>m;
memset(h,-1,sizeof h);
for(int i=1;i<=n;i++)
{
p[i]=i;
}
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
}
for(int i=1;i<=m;i++)
{
int a,b;
cin>>a>>b;
query[a].push_back({b,i});
query[b].push_back({a,i});
}
dfs(1,-1);
tarjan(1);
for(int i=1;i<=m;i++)
{
cout<<res[i]<<endl;
}
return 0;
}
二刷,又写了一遍倍增
总之这题就是求最近公共祖先
然后两点之间的距离=dist[i]+dist[j]-2*dist[anc]
#include <iostream>
#include <cstring>
#include <vector>
#include <queue>
#include <algorithm>
using namespace std;
const int N = 2e4 + 10, M = N * 2;
int n,m;
int h[N],e[M],ne[M],w[M],idx,st[N],dist[N];
int fa[N][18];
int depth[N];
queue<int> q;
void add(int a,int b,int c)
{
e[idx]=b;
ne[idx]=h[a];
w[idx]=c;
h[a]=idx++;
}
void bfs()
{
memset(depth,0x3f,sizeof depth);
//哨兵
depth[0]=0;
//默认1为根节点
depth[1]=1;
q.push(1);
while(q.size())
{
int t=q.front();
q.pop();
for(int i=h[t];i!=-1;i=ne[i])
{
int j=e[i];
if(depth[j]>depth[t]+1)
{
depth[j]=depth[t]+1;
dist[j]=dist[t]+w[i];
fa[j][0]=t;
q.push(j);
for(int k=1;k<=16;k++)
{
fa[j][k]=fa[fa[j][k-1]][k-1];
}
}
}
}
}
int lca(int a, int b)
{
if(depth[a] < depth[b]) swap(a, b);
for(int k = 16; k >= 0; k -- )
if(depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if(a == b) return b;
for(int k = 16; k >= 0; k -- )
{
if(fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
int main()
{
cin>>n>>m;
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
}
bfs();
for(int i=1;i<=m;i++)
{
int a,b;
cin>>a>>b;
//找到a和b的最近公共祖先
int anc = lca(a,b);
int res=dist[a]+dist[b]-2*dist[anc];
cout<<res<<endl;
}
return 0;
}