题目描述
难度分:2470
输入n(1≤n≤2×105)和长为n的数组a(1≤a[i]≤n)。
然后输入一棵无向树的n−1条边,节点编号从1到n。节点i的颜色是a[i]。
定义f(c)=包含颜色c的简单路径的数目。注:只有1个点也算路径。
输出f(1),f(2),…,f(n)。
输入样例1
3
1 2 1
1 2
2 3
输出样例1
5
4
0
输入样例2
1
1
输出样例2
1
输入样例3
2
1 2
1 2
输出样例3
2
2
输入样例4
5
1 2 3 4 5
1 2
2 3
3 4
3 5
输出样例4
5
8
10
5
5
输入样例5
8
2 7 2 5 4 1 7 5
3 1
1 2
2 7
4 5
5 6
6 8
7 8
输出样例5
18
15
0
14
23
0
23
0
算法
逆向思维
正难则反,计算不包含颜色c的简单路径数。所有路径数减去不包含颜色c的路径数就是答案。
去掉颜色c节点后,树分成了若干连通块。对于大小为m的连通块,其中有m×(m+1)2条简单路径。接下来的问题就在于如何快速计算各个连通块的大小?
如图,去掉粉色节点后,考虑包含节点y的连通块,它的大小等于:树y的大小,减去子树z1,z2,z3的大小之和。
如何计算子树z1,z2,z3的大小之和?额外用一个数组size[c]记录以颜色c为根的子树的大小之和。
但这样还有问题,考虑x更上面的节点,这样做z1,z2,z3的子树大小就会和x的子树大小相加。为了避免重复累加子树大小,直接用子树x的大小覆盖子树z1,z2,z3的大小之和。
在实现上,于dfs(x)的开头用临时变量old记录size[c];在dfs(x)的末尾,覆盖size[c]=old+子树x的大小。
复杂度分析
时间复杂度
DFS
遍历整棵树就能预处理出计算答案所需要的信息,接下来遍历i∈[1,n],对每个i节点O(1)计算答案即可。因此,整个算法的时间复杂度为O(n)。
空间复杂度
size数组和ans数组都是线性空间消耗;整棵树的邻接表空间消耗也是O(n);在DFS
的过程中,如果整棵树退化成链,递归深度就会达到O(n),空间复杂度仍然是O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
int n, c[N];
long long ans[N], Size[N];
vector<int> graph[N];
int dfs(int u, int fa) {
int old = Size[c[u]];
int sz = 1;
for(int v: graph[u]) {
if(v == fa) continue;
Size[c[u]] = 0;
int szW = dfs(v, u);
sz += szW;
int m = szW - Size[c[u]];
ans[c[u]] += m*(m + 1LL)/2;
}
Size[c[u]] = old + sz;
return sz;
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &c[i]);
graph[i].clear();
Size[i] = ans[i] = 0;
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs(1, 0);
for(int i = 1; i <= n; i++) {
int m = n - Size[i];
printf("%lld\n", n*(n + 1LL)/2 - ans[i] - m*(m + 1LL)/2);
}
return 0;
}