题目描述
难度分:1900
输入n(1≤n≤3×105)和长为n的数组a(−109≤a[i]≤109) 表示树上每个点的点权,然后输入这棵树的n−1条边(节点编号从1开始)。
执行如下操作恰好一次:
选一个点作为根节点,根节点的点权不变,它的儿子的点权增加1,其余点的点权增加2。
最小化这棵树的最大点权,并输出。
输入样例1
5
1 2 3 4 5
1 2
2 3
3 4
4 5
输出样例1
5
输入样例2
7
38 -29 87 93 39 28 -55
1 2
2 5
3 2
2 4
1 7
7 6
输出样例2
93
输入样例3
5
1 2 7 6 7
1 5
5 3
3 4
2 4
输出样例3
8
算法
换根DP
这棵树是无根树,可以先定一个根(比如节点1)DFS
遍历整棵树求得以1为根时所有节点的信息。然后再DFS
遍历一遍进行状态转移,计算以每个节点为根时的答案,最后所有答案的最小值就是本题的答案。
第一遍DFS
进行传统的树形DP
,求如下信息f:
状态定义
f[u]表示以u为根的子树中除自己之外的节点最大点权。
状态转移
设v为u的子节点,则有状态转移方程f[u]=maxv(a[v],f[v])。
换根
接下来再进行一遍DFS
进行状态转移,定义递归函数dfs2(u,fa,vx),其中u为当前遍历到的节点,fa为u的父节点,vx为整棵树除了以fa为根节点的子树之外,所有节点的点权最大值。
对于一个节点u,它的答案应该是max(a[u],a[fa]+1,vx+2,maxv(a[v]+1,f[v]+2)),其中v是u的子节点。当继续往v递归时就需要更新属于v的vx,它的值应该是max(vx,a[fa],maxv′(a[v′],f[v′])),其中v′是u的所有子节点中除了v之外的子节点。实现的时候有一些细节和边界条件需要注意,详见代码。
复杂度分析
时间复杂度
进行两次DFS
,时间复杂度是线性的,因此算法的复杂度为O(n)。
空间复杂度
DP
数组、图的邻接表,以及递归深度的空间消耗都是O(n)的,因此额外空间复杂度也为O(n)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 300010, INF = 0x3f3f3f3f;
int n, ans, a[N], f[N];
vector<int> graph[N];
void dfs1(int u, int fa) {
// f[u]是以u为根的子树中,不包括u节点自身的最大点权
for(int v: graph[u]) {
if(v == fa) continue;
dfs1(v, u);
f[u] = max(f[u], max(a[v], f[v]));
}
}
void dfs2(int u, int fa, int vx){
int res = max({a[u], a[fa] + 1, vx + 2});
// 先预处理出u的所有子节点
vector<int> children;
for(int v: graph[u]) {
if(v == fa) continue;
children.push_back(v);
}
int sz = children.size();
// 处理前后缀最值,便于快速求出u除当前子节点v之外的最大值
vector<int> left(sz, -INF), right(sz, -INF);
for(int i = 0; i < sz; i++) {
int v1 = children[i];
left[i] = max({i >= 1? left[i - 1]: -INF, a[v1], f[v1]});
int j = sz - 1 - i;
int v2 = children[j];
right[j] = max({j + 1 < sz? right[j + 1]: -INF, a[v2], f[v2]});
}
vx = max(vx, a[fa]);
for(int i = 0; i < sz; i++) {
int v = children[i];
int maxL = i >= 1? left[i - 1]: -INF;
int maxR = i + 1 < sz? right[i + 1]: -INF;
dfs2(v, u, max({vx, maxL, maxR}));
res = max(res, max(a[v] + 1, f[v] + 2));
}
ans = min(ans, res);
}
int main() {
scanf("%d", &n);
memset(a, -0x3f, sizeof(a));
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
graph[i].clear();
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
memset(f, -0x3f, sizeof(f));
dfs1(1, 0);
ans = INF;
dfs2(1, 0, -INF);
printf("%d\n", ans);
return 0;
}