题目描述
难度分:2048
输入n(2≤n≤2×105)和一棵无向树的n−1条边,节点编号从1到n。
定义f(i)为如下过程的方案数:
- 首先,在节点i写下数字1。
- 然后,选择一个与写下数字节点相邻的节点(没有写过数字),写下数字2。
继续,重复上述过程,依次写下数字3,4,…,n。
输出f(1),f(2),…,f(n)。
输入样例1
3
1 2
1 3
输出样例1
2
1
1
输入样例2
2
1 2
输出样例2
1
1
输入样例3
5
1 2
2 3
3 4
3 5
输出样例3
2
8
12
3
3
输入样例4
8
1 2
2 3
3 4
3 5
3 6
6 7
6 8
输出样例4
40
280
840
120
120
504
72
72
算法
换根DP
先考虑f(1)怎么算,如果从1开始DFS
遍历这棵树,按顺序记录遍历到的数字,会得到一个1~n的排列,总共有n! 种,其中肯定有不合法的。
比如,这个排列肯定要以1开头,所有不以1开头的排列都是不合法的。以1开头的排列个数是(n−1)!(剩下的n−1个数全排列),相当于把n!除以n。
同理,对于每棵子树v而言(对应着1~n排列中的一个子序列),不以v的标记数字开头的排列都是不合法的,同样要把方案数除以sz[v],即子树v的大小。
所以f(1)=n!sz[1]×sz[2]×…×sz[n]。
对于其他f(i),我们可以在f(1)的基础上,用换根DP
快速计算上式的分母。从节点u换到节点v,f[u]先除以sz[v],对于不在子树v中的节点,会形成一棵v为根节点朝上的子树,所以要乘以(n−size[v])。
上述过程中有除法同余,可以用费马小定理计算逆元。
复杂度分析
时间复杂度
对整棵树进行两次DFS
(一次求f[1],一次换根求f[u](u>1))就能得到答案。因此,整个算法的时间复杂度为O(n)。
空间复杂度
sz数组和f数组是线性的,为O(n)。预处理逆元需要三个辅助数组,均与n相关,空间消耗也是O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010, MOD = 1e9 + 7;
vector<int> graph[N];
int n, sz[N], f[N];
LL inv[N], finv[N], fac[N];
void get_inv(int n) {
inv[0] = inv[1] = 1;
for(int i = 2; i <= n; i++) {
inv[i] = (MOD - MOD/i) * inv[MOD % i] % MOD;
}
finv[0] = finv[1] = fac[0] = fac[1] = 1;
for(int i = 2; i <= n; i++) {
fac[i] = fac[i - 1] * i % MOD;
finv[i] = finv[i - 1] * inv[i] % MOD;
}
}
int fast_power(int a, int k, int p) {
int res = 1 % p;
while(k) {
if(k & 1) res = (LL)res * a % p;
a = (LL)a * a % p;
k >>= 1;
}
return res;
}
void dfs(int u, int fa) {
sz[u] = 1;
for(int v: graph[u]) {
if(v == fa) continue;
dfs(v, u);
sz[u] += sz[v];
}
f[1] = 1LL * f[1] * sz[u] % MOD;
}
void reroot(int u, int fa) {
for(int v: graph[u]) {
if(v == fa) continue;
f[v] = 1LL * f[u] * fast_power(sz[v], MOD - 2, MOD) % MOD * (n - sz[v]) % MOD;
reroot(v, u);
}
}
int main() {
if(!inv[0]) get_inv(N - 10);
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
graph[i].clear();
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
f[1] = 1;
dfs(1, 0);
reroot(1, 0);
for(int i = 1; i <= n; i++) {
printf("%d\n", fac[n] * fast_power(f[i], MOD - 2, MOD) % MOD);
}
return 0;
}