题目描述
难度分:2600
输入n(2≤n≤105)和长为n的数组a(1≤a[i]≤109),表示每个节点的点权。
然后输入一棵树的n−1条边,节点编号从1开始。
从这棵树中,选出两条不相交的路径,也就是没有节点会同时出现在两条路径中。输出这两条路径的节点点权之和的最大值。
输入样例1
9
1 2 3 4 5 6 7 8 9
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
输出样例1
25
输入样例2
2
20 10
1 2
输出样例2
30
算法
树形DP
这个状态定义和转移也太复杂了,又是学习的一天。用w[u]表示节点u的点权。
状态定义
f[u][0]表示以u为根的子树中,两条不相交链的最大点权和。f[u][1]表示以u为根的子树中,点权和最大的一条链。g[u]表示以u为根的子树中,u到叶子节点+另外一条链的最大点权和。h[u]表示以u为根的子树中,u的儿子节点son中f[son][1]的最大值,这个h主要就是用来保证两条链不相交。down[u]表示从u到叶子节点的最大点权和。
对此无根树随便定一个根节点1,则在这个状态定义下,答案就应该是f[1][0]。
状态转移
对于某个节点u,f[u][0]的转移来源有:
- f[v][0],其中v是u的一个子节点,这表示直接继承v子树的答案。
- f[u][1]+f[v][1],表示从子树u中点权和最大的链+子树v中点权和最大的链。
- down[u]+g[v],表示子树v内到叶子节点的链往上延伸穿过节点u形成的一条链+子树v中一条与之不相交的链。
- down[v]+g[u],与上一个转移类似,表示子树v到某个叶子节点的链往上延伸穿过节点u形成一条链+子树u中一条与之不相交的链。
以上4种情况选较大的转移。
g[u]的转移来源有:
- w[u]+g[v],表示直接将叶子节点到v的链延伸到u形成一条链+子树v中一条与之不相交的链。
- down[u]+f[v][1],表示子树v中点权和最大的一条链+子树u中从u到叶子节点点权和最大的一条链(两条链不相交,因为此时down[u]还没有更新完全,只有v节点的信息是全的,得不到u穿过v达到某个叶子节点的路径信息)。
- down[v]+w[u]+h[u],表示让v到某个叶子节点的路径往上延伸到u+u的子节点子树中最大点权和的那条链。
以上3种情况选较大的转移。
h[u]比较好维护,在遍历子节点v时维护就行,h[u]=maxvf[v][1]。
down[u]同理,在遍历子节点v时维护,down[u]=maxvdown[v]+w[u]。
复杂度分析
时间复杂度
状态数量是O(n)(准确来说是4n),单次转移的时间复杂度为O(1)。整个算法的时间复杂度就是遍历一遍树的时间复杂度,因为树的边也是O(n)级别(准确来说是n−1),所以时间复杂度为O(n)。
空间复杂度
除了输入的点权数组w,树的邻接表空间复杂度为O(n+m),其中m是边数,由于是树所以m=n−1,邻接表的空间复杂度为O(n)。DP
数组的空间复杂度为O(4n)=O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 100010, INF = 0x3f3f3f3f;
int n, w[N];
vector<int> graph[N];
LL f[N][2], g[N], h[N], down[N];
void dfs(int u, int fa) {
f[u][0] = f[u][1] = g[u] = down[u] = w[u];
h[u] = 0;
for(int v: graph[u]) {
if(v == fa) continue;
dfs(v, u);
f[u][0] = max(f[u][0], f[v][0]);
f[u][0] = max(f[u][0], f[u][1] + f[v][1]);
f[u][0] = max(f[u][0], down[u] + g[v]);
f[u][0] = max(f[u][0], g[u] + down[v]);
f[u][1] = max(f[u][1], f[v][1]);
f[u][1] = max(f[u][1], down[u] + down[v]);
g[u] = max(g[u], w[u] + g[v]);
g[u] = max(g[u], down[u] + f[v][1]);
g[u] = max(g[u], down[v] + w[u] + h[u]);
h[u] = max(h[u], f[v][1]);
down[u] = max(down[u], down[v] + w[u]);
}
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &w[i]);
graph[i].clear();
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs(1, 0);
printf("%lld\n", f[1][0]);
return 0;
}