题目描述
给定一棵包含 n 个节点的有根无向树,节点编号互不相同,但不一定是 1∼n。
有 m个询问,每个询问给出了一对节点的编号 x 和 y,询问 x 与 y的祖孙关系。
输入格式
输入第一行包括一个整数 表示节点个数;
接下来 n行每行一对整数 a 和 b,表示 a 和 b 之间有一条无向边。如果 b 是 −1,那么 a就是树的根;
第 n+2行是一个整数 m表示询问个数;
接下来 m行,每行两个不同的正整数 x 和 y,表示一个询问。
输出格式
对于每一个询问,若 x是 y 的祖先则输出 1,若 y 是 x 的祖先则输出 2,否则输出 0。
数据范围
1≤n,m≤4×104,
1≤每个节点的编号≤4×104
样例
输入样例:
10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19
输出样例:
1
0
0
0
2
求解最近公共祖先,一般有三种方法,见下图:
- 向上标记法
先将一个点向上一直到根所经过的节点记录下来,再对另一个点向上,如果到达某个点被标记过那么就是他们的公共祖先。 - 倍增思想
某两个点离他们的最近公共祖先的距离,也就是相差的层数,用二进制数表示,那么就是可以分解为若干个2^j次方的和。那么如果向上跳2^j步,如果刚好跳到公共祖先或是跳过了公共祖先,两个点到达的节点都是相同的,所以当跳2^j步到达相同节点时无法判断是否是最近公共祖先,所以从大到小枚举j,如果当前的j跳到的节点不相同,而>j跳到的节点相同,那么跳2^j步,一直搜索到j=0为止,那么最终到达的是两个节点的最近公共节点的下一层。(从这一层再跳2^0步就是最近公共祖先)
所以我们要做一些初始化,将每个节点跳2^j步会到达的节点存下来,并且存储每个点所在的层数,方便将不同层的点先放在同一层再寻找祖先。并且添加哨兵,如果跳过了根节点,那么置为0,第0号节点的层数置为0。哨兵的作用会在代码体现 - Tarjan算法
见题1171:距离
C++ 代码
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=40010,M=N*2;
int n,m;
int h[N],e[M],ne[M],idx;
int depth[N];
int q[N];
int fa[N][16];//最多40000个点,2^16次方就六万多了,最大跳2^15
void add(int a,int b){
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void bfs(int root){
memset(depth,0x3f,sizeof depth);
depth[0]=0;depth[root]=1;
int hh=0,tt=-1;
q[++tt]=root;
while(hh<=tt){
int t=q[hh++];
for(int i=h[t];~i;i=ne[i]){
int j=e[i];
if(depth[j]>depth[t]+1){
depth[j]=depth[t]+1;
q[++tt]=j;
fa[j][0]=t;//向上跳一步就是父节点
//向上跳2^k步,那么分解为2^(k-1)+2^(k-1),就是先跳2^(k-1)后,再跳2^(k-1)
for(int k=1;k<=15;k++){
fa[j][k]=fa[fa[j][k-1]][k-1];
}
}
}
}
}
int lca(int x,int y){
//越靠近根,层数是越小的
if(depth[x]<depth[y]) swap(x,y);//将层数大的放在x,将x向上走到与y同层再找祖先
for(int i=15;i>=0;i--){
//在设置哨兵depth[0]=0后,即使跳出去了,那么depth为0,依旧<depth[y]
if(depth[fa[x][i]]>=depth[y])//每次跳到最近的不越过的y的层
x=fa[x][i];
}
if(x==y) return y;//如果在同一层后就已经相等了,说明x和y本身就有祖宗和子孙的关系
//在同一层后进行最近公共祖先寻找
for(int i=15;i>=0;i--){
//fa 当跳出去的时候都为0,就是跳到了虚无的0节点,fa[x][i]和fa[y][i]还是相等的,是0
if(fa[x][i]!=fa[y][i]){
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];//跳到公共祖先的下一层后再上跳一步就是祖先
}
int main(){
cin>>n;
memset(h,-1,sizeof h);
int root=0;
for(int i=0;i<n;i++){
int a,b;
cin>>a>>b;
if(b==-1) root=a;
else add(a,b),add(b,a);
}
bfs(root);
cin>>m;
while(m--){
int x,y;
cin>>x>>y;
int p=lca(x,y);
if(p==x) cout<<1<<endl;
else if(p==y) cout<<2<<endl;
else cout<<0<<endl;
}
return 0;
}