题目描述
难度分:1700
输入T(≤104)表示T组数据。所有数据的n之和≤2×105。
每组数据输入n、k(1≤k≤n≤2×105)和长为k的数组a(1≤a[i]≤n)。
然后输入一棵无向树的n−1条边,节点编号从1到n。
我们标记了树上的k个节点,把这k个节点的编号记录在数组a中。
定义f[i]表示节点i到所有被标记节点的最大距离。输出mini∈[1,n]f[i]。
输入样例1
6
7 3
2 6 7
1 2
1 3
2 4
2 5
3 6
3 7
4 4
1 2 3 4
1 2
2 3
3 4
5 1
1
1 2
1 3
1 4
1 5
5 2
4 5
1 2
2 3
1 4
4 5
10 8
1 2 3 4 5 8 9 10
2 10
10 5
5 3
3 1
1 7
7 4
4 9
8 9
6 1
10 9
1 2 4 5 6 7 8 9 10
1 3
3 9
9 4
4 10
10 6
6 7
7 2
2 5
5 8
输出样例1
2
2
0
1
4
5
输入样例2
3
6 1
3
1 2
1 3
3 4
3 5
2 6
5 3
1 2 5
1 2
1 3
2 4
3 5
7 1
2
3 2
2 6
6 1
5 6
7 6
4 5
输出样例2
0
2
0
算法
换根DP
比较明显的换根树形DP
,先以1为根进行一次DFS
,得到f[1]。
树形DP
状态定义
f[u]表示以u(整个树的树根定为1)为根的子树中,被标记的节点距离u的最大距离。
状态转移
如果u被标记了,初始化f[u]=0,否则f[u]=−∞。
遍历u的直接子节点,子树v中被标记的节点到u的距离就是先到v,再到u的距离,状态转移方程为f[u]=1+minvf[v]。
DFS
换根
接下来再做一次DFS
,得到所有的f[i]。再开一个DP
数组更加清晰,把题中的f[i]存成dp[i]。定义递归函数dfs(u,up),其中up表示u子树(以1为根的意义下)以外所有节点中被标记的节点距离u的最大距离,此时就有dp[u]=max(f[u],up)。然后进行状态转移,遍历u的所有子节点v,u子树之外的标记节点到某个节点v,需要从u走一步到v,因此距离是up+1。而v的兄弟节点v′子树中,被标记的节点需要先走到v′,再向上走一步到u,最后向下走一步到v,需要多两步。
所以v的up值应该是max(up+1,2+maxv′f[v′]),v′是v在子树u中的兄弟节点。maxv′f[v′]可以通过前后缀分解快速求得。
复杂度分析
时间复杂度
对树进行两次DFS
,时间复杂度为O(n)。
空间复杂度
树的邻接表graph,标记数组flag,以及两个DP
数组的空间都是线性的,因此额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 200010, INF = 0x3f3f3f3f;
vector<int> graph[N];
int T, n, k, res, flag[N], f[N], dp[N];
void dfs1(int u, int fa) {
f[u] = flag[u]? 0: -INF;
for(int v: graph[u]) {
if(v == fa) continue;
dfs1(v, u);
f[u] = max(f[u], 1 + f[v]);
}
}
void dfs2(int u, int fa, int up) {
dp[u] = max(f[u], up);
vector<int> vals;
for(int v: graph[u]) {
if(v == fa) continue;
vals.push_back(f[v]);
}
if(vals.empty()) {
return;
}
int sz = vals.size();
vector<int> premax(sz), sufmax(sz);
premax[0] = vals[0];
sufmax[sz - 1] = vals[sz - 1];
for(int i = 1; i < sz; i++) {
premax[i] = max(premax[i - 1], vals[i]);
sufmax[sz - 1 - i] = max(sufmax[sz - i], vals[sz - 1 - i]);
}
int i = 0;
for(int v: graph[u]) {
if(v == fa) continue;
int left = i > 0? premax[i - 1]: -INF;
int right = i + 1 < sz? sufmax[i + 1]: -INF;
dfs2(v, u, max(max(left, right) + 2, ((up < -INF/2 && flag[u])? 0: up) + 1));
i++;
}
}
int main() {
scanf("%d", &T);
while(T--) {
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) {
flag[i] = f[i] = dp[i] = 0;
graph[i].clear();
}
for(int i = 1; i <= k; i++) {
int a;
scanf("%d", &a);
flag[a] = 1;
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs1(1, 0);
res = INF;
dfs2(1, 0, -INF);
for(int i = 1; i <= n; i++) {
if(dp[i] < res) res = dp[i];
}
printf("%d\n", res);
}
return 0;
}