https://ac.nowcoder.com/acm/contest/11224/F
1.无法从k到达所有的景点
显然答案是从k向所有不在k为跟的子树的景点的lca连边即可
2. 所有景点都在k为跟的子树内
枚举确定向哪点连边最优
以i为跟的子树内景点到i的距离, num[i] 以i为跟的子树内景点的个数
假设连边k -> v
那么答案变成 $sum[k] - num[v] * (depth[v] - depth[k])$
枚举k子树内的点即可
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define endl "\n"
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
#define debug(a) cout<<#a<<' '<<a<<endl;
#define cut cout<<"-------------------"<<endl;
typedef long long ll;
typedef pair<int,int> pii;
typedef double db;
const int N = 2e5+10;
const ll mod=1e9 + 7, inf = 0x3f3f3f3f;
mt19937 mrand(time(0));
int rnd(int x) { return mrand() % x;}
ll qmi(ll a,ll b,ll p) {ll res=1;a%=p; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%p;a=a*a%p;}return res;}
ll gcd(ll a,ll b) { return b?gcd(b,a%b):a;}
#define int ll
int n, m , k;
vector<int> g[N];
int b[N];
bool vis[N];
int depth[N], fa[N][25];
vector<int> res; //k子树内的所有点
int sum[N], num[N];//sum[i] 以i为跟的子树内景点到i的距离, num[i] 以i为跟的子树内景点的个数
unordered_map<int, int> mp;
bool check()
{
queue<int> q;
q.push(k);
vis[k] = 1;
while(SZ(q))
{
int t = q.front();
q.pop();
for(int p : g[t])
{
q.push(p);
res.pb(p);
vis[p] = 1;
}
}
for(int i = 1; i <= m; i ++)
if(!vis[b[i]])
return false;
return true;
}
void bfs()
{
queue<int> q;
q.push(1);
depth[1] = 1, depth[0] = 0;
while(SZ(q))
{
int t = q.front();
q.pop();
for(int p : g[t])
{
q.push(p);
depth[p] = depth[t] + 1;
fa[p][0] = t;
for(int i = 1; i <= 20; i ++)
{
int anc = fa[p][i - 1];
fa[p][i] = fa[anc][i - 1];
}
}
}
}
void dfs(int u)
{
if(mp[u])
num[u] ++;
for(int p : g[u])
{
dfs(p);
num[u] += num[p];
sum[u] += sum[p] + num[p];
}
}
int lca(int x, int y)
{
if(depth[x] < depth[y])
swap(x, y);
for(int k = 20; k >= 0; k --)
{
if(depth[fa[x][k]] >= depth[y])
x = fa[x][k];
}
if(x == y)
return x;
for(int k = 20; k >= 0; k --)
{
if(fa[x][k] != fa[y][k])
x = fa[x][k], y = fa[y][k];
}
return fa[x][0];
}
void init()
{
cin >> n;
for(int i = 1; i < n; i ++)
{
int u, v;
cin >> u >> v;
g[u].pb(v);
}
cin >> m >> k;
for(int i = 1; i <= m; i ++)
cin >> b[i], mp[b[i]] = 1;
bfs();
dfs(1);
// cout << lca(5, 6);
// exit(0);
if(!check())//1. k 不可以到达所有的b 向不在k子树的所有点的lca连边
{
// cout <<"YES" << endl;
int f = 0;
int ans = 0;
for(int i = 1; i <= m; i ++)
{
if(!vis[b[i]])
{
if(f == 0)
f = b[i];
else
f = lca(f, b[i]);
// cout << b[i] << ' ';
}
else
{
ans += depth[b[i]] - depth[k];
}
}
int cnt = 0;
for(int i = 1; i <= m; i ++)
{
if(!vis[b[i]])
{
cnt ++;
ans += depth[b[i]] - depth[f];
}
}
cout << ans + cnt<< endl;
}
else//2. k可以到达所有的b
{
int ans = sum[k];
for(int p : res)
{
ans = min(ans, sum[k] - num[p] * (depth[p] - depth[k] - 1));
}
cout << ans;
}
}
void solve()
{
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int _;
// cin>>_;
_ = 1;
while(_--)
{
init();
solve();
}
return 0;
}