Tarjan算法
$O(n + m)$
解题思路:首先根据给出的中序序列和前序序列建树,由于题目给出的结点值不便于后续处理查询,因此先将每个数映射到1~n,在读取中序序列时记录每个结点的编号,正好让结点在中序遍历中的位置作为其编号,可以便于我们建立二叉树。建树完成后,记录所有的合法查询,不合法的查询将其做好标记,query[u]记录了{v, idx},使用离线的Tarjan算法即可求出u, v的祖先节点编号,最后根据编号对应的数值输出即可。
C++ 代码
#include <cstdio>
#include <cstring>
#include <vector>
#include <unordered_map>
using namespace std;
const int N = 10010, M = 1010, INF = 1e9;
typedef pair<int, int> PII;
unordered_map<int, int> ma; //节点值到节点标号的映射
unordered_map<int, int> pre; //节点标号到节点值的映射
int inorder[N], preorder[N];
int lson[N], rson[N], pivot; // lson[i]表示i的左儿子, rson[i]表示i的右儿子, 为0表示空指针
PII a[M];
vector<PII> query[N];
int ans[M], vis[N], f[N];
int n, m;
void dfs(int pos, int l, int r){
if(l >= r) return;
if(pos > l){
lson[pos] = ma[preorder[++pivot]];
dfs(lson[pos], l, pos - 1);
}
if(pos < r){
rson[pos] = ma[preorder[++pivot]];
dfs(rson[pos], pos + 1, r);
}
}
int find(int x){
if(f[x] == x) return x;
return f[x] = find(f[x]);
}
void tarjan(int u){
vis[u] = 1;
if(lson[u] && !vis[lson[u]]){
tarjan(lson[u]);
f[lson[u]] = u;
}
if(rson[u] && !vis[rson[u]]){
tarjan(rson[u]);
f[rson[u]] = u;
}
for(int i = 0; i < query[u].size(); i++){
int v = query[u][i].first, idx = query[u][i].second;
if(vis[v] == 2){
int parent = find(v);
ans[idx] = parent;
}
}
vis[u] = 2;
}
int main(){
scanf("%d %d", &m, &n);
int id = 0;
for(int i = 1; i <= n; i++){
scanf("%d", &inorder[i]);
ma[inorder[i]] = ++id;
pre[id] = inorder[i];
}
for(int i = 1; i <= n; i++) scanf("%d", &preorder[i]), f[i] = i;
pivot = 1;
int root = ma[preorder[1]];
dfs(root, 1, n); //中序和前序遍历建树, 直接用结点编号代替原数值
int u, v;
for(int i = 1; i <= m; i++){
scanf("%d %d", &u, &v);
a[i].first = u, a[i].second = v;
if(ma.find(u) == ma.end() || ma.find(v) == ma.end()){
ans[i] = INF;
}
else{
u = ma[u], v = ma[v];
if(u != v){
query[u].push_back({v, i});
query[v].push_back({u, i});
}
else ans[i] = u;
}
}
tarjan(root); // tarjan算法离线求LCA
for(int i = 1; i <= m; i++){
int u = a[i].first, v = a[i].second;
if(ans[i] == INF){
if(ma.find(u) == ma.end() && ma.find(v) == ma.end())
printf("ERROR: %d and %d are not found.\n", u, v);
else if(ma.find(u) == ma.end())
printf("ERROR: %d is not found.\n", u);
else
printf("ERROR: %d is not found.\n", v);
}
else{
int parent = pre[ans[i]];
if(parent == u){
printf("%d is an ancestor of %d.\n", u, v);
}
else if(parent == v){
printf("%d is an ancestor of %d.\n", v, u);
}
else{
printf("LCA of %d and %d is %d.\n", u, v, parent);
}
}
}
return 0;
}