题目描述
难度分:1800
输入n(2≤n≤2×105)和长为n的数组a,只包含0和1。
然后输入一棵无向树的n−1条边,节点编号从1到n。
a[i]表示节点i的颜色,0表示黑色,1表示白色。
定义f(i)为包含节点i的连通块中,白色节点个数减去黑色节点个数的最大值。输出f(1),f(2),…,f(n),注意f(i)可能是负数,见样例2。
输入样例1
9
0 1 1 1 0 0 0 0 1
1 2
1 3
3 4
3 5
2 6
4 7
6 8
5 9
输出样例1
2 2 2 2 2 1 1 0 2
输入样例2
4
0 0 1 0
1 2
1 3
1 4
输出样例2
0 -1 1 -1
算法
换根DP
比较明显的一个换根DP
题目,因为每个节点都要求个最大值。
树形DP
状态定义
f[u]表示仅考虑以u为根的子树,在包括节点u的情况下,能找到的连通块“白-黑”最大值。
状态转移
如果u是白色,自增f[u],否则自减f[u]。然后遍历u的所有子节点v,如果f[v]>0就将其累加在f[u]上。状态转移方程为f[u]=delta+Σvmax(0,f[v]),其中a[u]=1时delta=1,a[u]=0时delta=−1。
DFS
换根
换根函数为dfs(u,vx),其中vx表示u上面的节点(不包括u)能够贡献的“白-黑”最大值。还是遍历u的所有子节点vi,对于一个子节点vi,需要从u开始往vx考虑,将vx作为根,那么vx对于vi节点就要更新为vx+Σi−1j=0max(0,f[vj])+Σcnt−1j=i+1max(0,f[j])+delta。
其中a[u]=1时delta=1,a[u]=0时delta=−1,cnt是u的子节点数目。此时如果vx<0,对于vi来说那还不如舍弃vi上面的贡献,所以它的vx应该是max(0,vx+Σi−1j=0max(0,f[vj])+Σcnt−1j=i+1max(0,f[j])+delta)。这两段求和操作可以通过前后缀分解先预处理出来,从而O(1)计算这两段求和。
节点u的最终答案就应该是g[u]=f[u]+vx。
复杂度分析
时间复杂度
本质上就是对整棵无向树进行了两次DFS
,因此算法的时间复杂度为O(n)。
空间复杂度
树的邻接表空间复杂度为O(n);f和g两个DP
数组空间复杂度为O(n);在DFS
的过程中需要前后缀分解,需要开辟的前后缀数组总长度在O(n)级别;DFS
在最差情况下递归深度在O(n)级别(树退化成一条链)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
int n, a[N], f[N], g[N];
vector<int> graph[N];
void dfs1(int u, int fa) {
f[u] = a[u]? 1: -1;
for(int v: graph[u]) {
if(v == fa) continue;
dfs1(v, u);
if(f[v] > 0) f[u] += f[v];
}
}
void dfs2(int u, int fa, int vx) {
g[u] = f[u] + vx;
vector<int> children;
for(int v: graph[u]) {
if(v == fa) continue;
children.push_back(f[v]);
}
int sz = children.size();
vector<int> pre(sz), suf(sz);
for(int i = 1, j = sz - 2; i < sz; i++, j--) {
pre[i] = pre[i - 1] + max(0, children[i - 1]);
suf[j] = suf[j + 1] + max(0, children[j + 1]);
}
int i = 0;
for(int v: graph[u]) {
if(v == fa) continue;
int nxt = vx + (a[u]? 1: -1) + pre[i] + suf[i];
dfs2(v, u, max(0, nxt));
i++;
}
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
graph[i].clear();
f[i] = 0; // f[i]表示以i为根的子树中,在包含i的情况下,连通块中“白-黑”的最大值
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0, 0);
for(int i = 1; i <= n; i++) {
printf("%d ", g[i]);
}
return 0;
}