题目描述
难度分:2600
输入T(≤5×104)表示T组数据。所有数据的n之和≤2×105。
每组数据输入n(2≤n≤105),k(0≤k≤n)和长为n的数组 a(1≤a[i]≤1000)。
然后输入一棵无向树的n−1条边,节点编号从1开始。根节点为1。节点i的点权为a[i]。
你可以执行如下操作至多k次:
- 选择一个没有操作过的节点v和一个整数x,其中x必须满足x是子树v中所有点权的公约数。然后把子树v中的每个点权都乘上x。
输出操作后,a[1]最大是多少。
输入样例
2
5 2
24 12 24 6 12
1 2
1 3
2 4
2 5
5 3
24 12 24 6 12
1 2
1 3
2 4
2 5
输出样例
288
576
算法
树形DP
比较容易看出来的是每次选择的x都需要是a[1]的因子,不然是没有用的,并且我们自底向上操作是最划算的。而每个节点只能操作一次,因此也比较容易能够想到树形DP
。
状态定义
f[u][c]表示以u为根的子树中所有节点的点权都为c的倍数时的最小操作次数。因此我们DFS完之后,直接遍历a[1]的所有因子factor,只要f[1][factor]≤k就说明所有节点都可以变成factor的倍数,最大的a[1]×factor就是答案。
状态转移
以当前节点为根的子树可操作也可不操作,有这两种策略:
- 如果不操作,那么f[u][c]=Σvf[v][c]其中v是u的子节点,并且要求c|a[u](b|a表示a能被b整除)。
- 如果操作,为了使得c|(a[u]×j)成立,我们找到最小的因子j就可以了(即最小的j满足c|j2,因为为了a[1]最大,我们每一步操作需要乘以最大因子,那么f[u][c]的转移来源就应该是f[v][j]。子节点的点权乘以j就会变成j2的倍数,那此时要想达成c|(a[u]×j)就还要满足j|a[u])。状态转移方程为f[u][c]=1+Σvf[v][j]。
复杂度分析
时间复杂度
DFS遍历一遍树的时间复杂度为O(n),每个节点需要枚举因子j,而[1,1000]范围内因子的最大个数为32,所以整体的时间复杂度为O(32n)。
空间复杂度
树的邻接表空间复杂度为O(n)。[1,1000]中每个数的因子存入数组d中,d[i]中为i的因子,空间复杂度可以看成是个常数较大的O(n)。预处理出每个数i满足i|j2的最小数g[i]=j存入到g数组中,这个空间可以开成常数的,但是我在实现的时候直接开了O(n)。瓶颈在于DP
矩阵fn×1000,它的空间复杂度为O(1000n),也是整个算法的额外空间复杂度。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const int N = 100005;
vector<int> d[N];
int g[N]; // 最小的g[i]满足i|g[i]*g[i]
int a[N];
vector<int> graph[N];
int f[N][1005];
void dfs(int u, int fa) {
for(auto v: graph[u]) {
if(v == fa) continue;
dfs(v, u);
}
for(auto c: d[a[1]]) {
// 枚举根节点的因数
f[u][c] = INF; // 初始化为无效状态
// 第一种转移
if(a[u] % c == 0) {
int t = 0;
for(auto v: graph[u]) {
if(v == fa) continue;
if(f[v][c] == INF) {
t = INF;
break;
}
t += f[v][c];
}
f[u][c] = min(f[u][c], t);
}
if(u == 1) {
// 根节点没有第二种转移
f[u][c]++; // 根节点再进行一次操作
continue;
}
// 第二种转移
if(a[u] % g[c] == 0) {
int t = 1;
for(auto v: graph[u]) {
if(v == fa) continue;
if(f[v][g[c]] == INF) {
t = INF;
break;
}
t += f[v][g[c]];
}
f[u][c] = min(f[u][c], t);
}
}
}
int main() {
// 预处理[1,1000]内每个数的因数
for(int i = 1; i <= 1000; i++) {
for(int j = i; j <= 1000; j += i) {
d[j].push_back(i);
}
for(int j = 1;; j++) {
if(j * j % i == 0) {
g[i] = j;
break;
}
}
}
int T;
scanf("%d", &T);
while(T--) {
int n, k;
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) graph[i].clear();
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs(1, 0);
int ans = a[1];
for(auto factor: d[a[1]]) {
if(f[1][factor] <= k) ans = max(ans, a[1]*factor);
}
printf("%d\n", ans);
}
return 0;
}