题 目 链 接
我焯,太难了...我不会啊啊啊啊...
参考题解:
https://blog.csdn.net/weixin_43960414/article/details/117300181?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522165104864216782350918730%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=165104864216782350918730&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-117300181.142^v9^pc_search_result_control_group,157^v4^control&utm_term=C.+Trees+of+Tranquillity&spm=1018.2226.3001.4449
化简一下题意就是:
给两棵树T1,T2,根节点都是1,找出一个点集S,
使得在T1中存在一条链包含点集中的所有点 && 点集中任意两点在T2中没有祖先关系。
分析一下:(图论dfs欧拉序列 + 数据结构set)
(1)首先,任意两点在树中没有祖先关系代表着什么?
说明这两点在树中的欧拉序列的区间没有交集!!!
(一般树中节点祖先关系的判断可以用欧拉序列!!!)
而且注意到一点:欧拉序列中点的区间只有包含和无交集两种情况!!!
(2)所以下面就好分析了:
我们直接在T1中dfs,这样就是链,然后在T2的欧拉序列中用区间判断!
对于两个点,如果他们的区间没有交集,则要选;
否则,我们贪心的想,那肯定选区间小的啊!
(3)那么我们用什么数据结构去维护区间 以及 如何维护呢???
首先,我们把T2中每个点欧拉序列区间的左右端点都存一下L[x]R[x],而且存一下映射;
然后我们用set去维护每个点的区间的左端点L[x];
对于每个dfs到的点x,我们在set中找第一个>=L[x]的左端点,然后判断一下是否包含或无交集,维护一下就行。
记得dfs最后要还原现场
(4)写的时候发现一个逻辑易错的点:
我们在set二分找到it时,要先判断if(it != s.begin())再判断if(it == s.end() || (*it) > R[u])
不然我们可以一直找到it == s.end()但是it和前面的存在包含关系,而大区间又没有删掉,所以出错.
太妙啊!
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll,ll> PII;
const int maxn = 1e6+7;
const int mod = 1e9+7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
vector<int> g1[maxn],g2[maxn];
int L[maxn];
int R[maxn];
int mp[maxn];
int idx;
int n;
int ans;
set<int> s;
void dfs_oula(int u,int fa)
{
L[u] = ++idx;
for(auto j : g2[u])
{
if(j == fa) continue;
dfs_oula(j,u);
}
R[u] = ++idx;
mp[L[u]] = R[u];
}
void dfs(int u,int fa)
{
int num = 0;
if(!s.size()) s.insert(L[u]);
else
{
auto it = s.lower_bound(L[u]);
if(it != s.begin())
{
it --;
if(mp[(*it)] >= R[u])
{
num = (*it);
s.erase(num);
s.insert(L[u]);
}
else
{
it ++;
if(it == s.end() || *it > R[u]) s.insert(L[u]);
}
}
else if(it == s.end() || (*it) > R[u]) s.insert(L[u]);
}
ans = max(ans,(int)s.size());
for(auto x : g1[u])
{
if(x == fa) continue;
dfs(x,u);
}
if(s.find(L[u]) != s.end()) s.erase(L[u]);
if(num) s.insert(num);
}
void solve()
{
scanf("%d",&n);
idx = 0;
s.clear();
for(int i=1;i<=n;i++)
{
g1[i].clear();
g2[i].clear();
mp[i] = 0;
}
for(int i=2;i<=n;i++)
{
int x; cin>>x;
g1[x].push_back(i);
g1[i].push_back(x);
}
for(int i=2;i<=n;i++)
{
int x; cin>>x;
g2[x].push_back(i);
g2[i].push_back(x);
}
dfs_oula(1,-1);
ans = 0;
dfs(1,-1);
cout<<ans<<endl;
}
int main()
{
int t;
cin>>t;
while(t--)
{
solve();
}
return 0;
}