题目描述
f[i,j]表示从i开始,向上走2^j步所能达到的节点
1. 先将两个节点跳到同一层
2. 让两个点同时往上跳,一直跳到最近公共祖先的下一层
3. 有一个哨兵,如果从i开始跳2^j步会跳过根节点,则fa[i,j]=0
4. 预处理分为两步:预处理deep数组;预处理倍增数组
#include <iostream>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;
const int N=40010,M=2*N;
int n,m;
int h[N],ne[M],e[M],idx;
int deep[N],fa[N][16];
queue<int> q;
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void bfs(int root)
{
memset(deep,0x3f,sizeof deep);
//哨兵
deep[0]=0;
deep[root]=1;
q.push(root);
while(q.size())
{
int t=q.front();
q.pop();
for(int i=h[t];i!=-1;i=ne[i])
{
int j=e[i];
if(deep[j]>deep[t]+1)
{
deep[j]=deep[t]+1;
q.push(j);
//j的向上一步为t
fa[j][0]=t;
for(int k=1;k<=15;k++)
{
//j向上跳2^k步,等于j向上跳2^k-1步之后再跳2^k-1步
fa[j][k]=fa[fa[j][k-1]][k-1];
}
}
}
}
}
//寻找最近公共祖先
int lca(int a,int b)
{
//首先让a在b下面
if(deep[a]<deep[b]) swap(a,b);
for(int k=15;k>=0;k--)
{
//如果还没达到b,继续跳
if(deep[fa[a][k]]>=deep[b])
{
a=fa[a][k];
}
}
if(a==b) return a;
//否则两个一起跳
for(int k=15;k>=0;k--)
{
if(fa[a][k] != fa[b][k])
{
a=fa[a][k];
b=fa[b][k];
}
}
return fa[a][0];
}
int main()
{
cin>>n;
int root=0;
memset(h,-1,sizeof h);
for(int i=0;i<n;i++)
{
int a,b;
cin>>a>>b;
if(b==-1) root=a;
else
{
add(a,b);
add(b,a);
}
}
bfs(root);
cin>>m;
while(m--)
{
int a,b;
cin>>a>>b;
int p=lca(a,b);
if(p==a) cout<<"1"<<endl;
else if(p==b) cout<<"2"<<endl;
else
cout<<0<<endl;
}
return 0;
}
二刷,再次记住:BFS求最短路时可以不用st数组,因为第一次遍历到一定是最短的!
注意lca中,当a跳到与b同一层时,若ab相等注意返回
#include <iostream>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;
const int N=40010;
int h[N],ne[2*N],e[2*N],idx;
int depth[N],fa[N][16];
int n,m;
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void bfs(int root)
{
memset(depth,0x3f,sizeof depth);
depth[0]=0;
depth[root]=1;
queue<int> q;
q.push(root);
while(q.size())
{
int t=q.front();
q.pop();
//bfs 搜到的第一次,一定是最短的,因此不需要st数组
for(int i=h[t];i!=-1;i=ne[i])
{
int j=e[i];
//如果遍历到他的父节点的话,根本不会进入下面这个if,因此不用设定st
if(depth[j] > depth[t]+1)
{
depth[j]=depth[t]+1;
//j往上跳2的0次方步为t
fa[j][0]=t;
for(int k=1;k<=15;k++)
{
fa[j][k]=fa[fa[j][k-1]][k-1];
}
q.push(j);
}
}
}
}
int lca(int a,int b)
{
if(depth[a]<depth[b]) swap(a,b);
//先让a跳到与b同一层,注意一定要让k从大到小
for(int k=15;k>=0;k--)
{
if(depth[fa[a][k]]>=depth[b])
{
a=fa[a][k];
}
}
//这一句不能忘!!
if(a==b) return a;
//两个节点再同时往上跳,跳到公共祖先的下一层
//注意也要让k从大到小跳,因为二进制要从大到小减,这样保证能表示每一个数
//如果跳出去了,会跳到哨兵,此时f[a][k]==f[b][k],不会更新a和b
for(int k=15;k>=0;k--)
{
if(fa[a][k]!=fa[b][k])
{
b=fa[b][k];
a=fa[a][k];
}
}
return fa[a][0];
}
int main()
{
cin>>n;
memset(h,-1,sizeof h);
int root;
for(int i=1;i<=n;i++)
{
int a,b;
cin>>a>>b;
if(b==-1) root=a;
else
{
add(a,b);
add(b,a);
}
}
bfs(root);
cin>>m;
while(m--)
{
int a,b;
cin>>a>>b;
int res=lca(a,b);
if(res==a) cout<<1<<endl;
else if(res==b) cout<<2<<endl;
else
cout<<0<<endl;
}
return 0;
}