题目描述
难度分:1700
输入T(≤104)表示T组数据。所有数据的n之和≤2×105。
每组数据输入n(2≤n≤2×105)、k(1≤k≤109)、c(1≤c≤109)。
然后输入一棵无向树的n−1条边(节点编号从1开始),每条边的边权都是k。
树的根节点是1。每次操作,你可以花费c,把树的根节点改成其邻居之一。定义
- 树的的得分为其高度(根节点到最远叶子的距离)。
- 在计算最终得分时,要从树的得分中,扣除操作的总花费。
输出最终得分的最大值。
输入样例
4
3 2 3
2 1
3 1
5 4 1
2 1
4 2
5 4
3 4
6 5 3
4 1
6 1
2 6
5 1
3 2
10 6 4
1 3
1 9
9 7
7 6
6 4
9 2
2 8
8 5
5 10
输出样例
2
12
17
32
算法
换根DP
感觉没昨天那题难,甚至可以说容易很多,算是比较容易看出来的换根树形DP
,因为题面中已经明确提到了换根操作。因此,先将根定为1,做一遍树形DP
。然后再做一遍DFS
进行换根,计算每个节点作为根时的得分,维护最大的得分。
树形DP
状态定义
f[u]表示以u为根的子树中,最远叶子节点到u的距离。
状态转移
设v为u的子节点,状态转移方程为f[u]=k+maxvf[v]。
在DFS
过程中再处理出一个depth数组,depth[u]表示节点u的深度(深度是在根为1的情况下计算得到)。
DFS
换根
然后跑dfs(u,fa,vx)换根,fa是u访问前最后一个访问的节点,防止走回头路。vx是u上面的所有节点到u的最大距离。初始情况下为dfs(1,0,0),维护max(vx,f[u])−depth[u]×c的最大值即可。
当从u转移到v时需要对vx进行更新,v上面节点到v的最长距离为max(maxv′f[v′]+k,vx)+k,其中v′是v以u为父节点的兄弟节点,它们要先走k的边权到u,然后从u上下来到v。为了快速求得maxv′f[v′]+k,需要对u的子节点做一个前后缀分解,维护子节点的前缀最大f[v]+k值和后缀最小f[v]+k值。
复杂度分析
时间复杂度
换根DP
的本质就是对整棵树做两遍DFS
,时间复杂度就是O(n+m),m是边数,n是节点数,而数的边数m=n−1,所以时间复杂度为O(n)。
空间复杂度
树的邻接表空间为O(n),DP
数组f和高度数组depth都是线性空间O(n),DFS
的过程中最大递归深度也可能达到O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010;
int T, n;
vector<int> graph[N];
LL k, c, ans, f[N], depth[N];
LL dfs1(int u, int fa, LL d) {
depth[u] = d;
for(int v: graph[u]) {
if(v == fa) continue;
f[u] = max(f[u], k + dfs1(v, u, d + 1));
}
return f[u];
}
void dfs2(int u, int fa, LL vx) {
ans = max(ans, max(f[u], vx) - depth[u]*c);
vector<LL> child;
for(int v: graph[u]) {
if(v == fa) continue;
child.push_back(f[v] + k);
}
int sz = child.size();
vector<LL> pre(sz), suf(sz);
for(int i = 0; i < sz; i++) {
pre[i] = max(i? pre[i - 1]: 0, child[i]);
int j = sz - 1 - i;
suf[j] = max(j + 1 < sz? suf[j + 1]: 0, child[j]);
}
int i = 0;
for(int v: graph[u]) {
if(v == fa) continue;
dfs2(v, u, max(max(i? pre[i - 1]: 0, i + 1 < sz? suf[i + 1]: 0), vx) + k);
i++;
}
}
int main() {
scanf("%d", &T);
while(T--) {
scanf("%d%lld%lld", &n, &k, &c);
for(int i = 1; i <= n; i++) {
graph[i].clear();
f[i] = depth[i] = 0;
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
ans = dfs1(1, 0, 0);
dfs2(1, 0, 0);
printf("%lld\n", ans);
}
return 0;
}