树形dp
设 f[i] 是以 i 为起点能抓到蝴蝶的最大数量, 那么 f[1] 就是答案
假设现在处于点 x , 第一步肯定是选一个子节点搜, 第二步有两种情况, 继续往下搜和返回父节点再搜
情况 1 : 继续往下搜
则 f[x]=max(f[x],a[x]+a[yi]+∑yi是x的子节点(f[yi]−a[yi])
情况 2 : 返回父节点
这种情况只会在 t[yi]=3 时发生, 因为如果不成立走这步一定是亏的
假设现在处于点 x, 先选了个点 z, 然后回到 x, 再选择另一个点y(t[y]=3)
此时子树z的贡献就是a[z]+∑ki是z的子节点(f[ki]−a[ki])
由于n的范围是1e5, 状态转移需要O(1),所以需要预处理所有的a[x]+∑yi是x的子节点(f[yi]−a[yi])
令g[x]=a[x]+∑yi是x的子节点(f[yi]−a[yi])
所以情况 1 可以写成 f[x]=max(f[x],a[yi]+g[x])
那么情况 2 中子树 z 的贡献就是 g[z]
此时 f[x]=max(f[x],g[x]−(f[z]−a[z])+g[z]+a[y])
其中 z 可以是 x 的任意一个子节点, y 要满足 t[y]=3
由于点 y 和 z 不能是同一个点, 只要记录 g[z]−(f[z]−a[z]) 的最大值和次大值即可
代码实现
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define rep(i, l, r) for(int i = l; i <= (int)r; i ++)
#define per(i, r, l) for(int i = r; i >= (int)l; i --)
#define debug(a) cout << #a << " = " << a << '\n';
#define PII pair<int, int>
const int N = 1e6 + 6;
vector<int> e[N];
int a[N], t[N], f[N], g[N];
void dfs(int x, int fa) {
int d1 = -1e9, d2 = -1e9;
g[x] = a[x];
for(auto y : e[x]) {
if(y == fa) continue;
dfs(y, x);
g[x] += f[y] - a[y];
int temp = g[y] - (f[y] - a[y]);
if(temp >= d1) d2 = d1, d1 = temp;
else if(temp > d2) d2 = temp;
}
f[x] = g[x];
for(auto y : e[x]) {
if(y == fa) continue;
f[x] = max(f[x], g[x] + a[y]);
if(t[y] == 3) {
if(g[y] - f[y] + a[y] == d1) f[x] = max(f[x], g[x] + a[y] + d2);
else f[x] = max(f[x], g[x] + a[y] + d1);
}
}
}
void solve() {
int n;
cin >> n;
rep(i, 1, n) {
cin >> a[i];
f[i] = g[i] = 0;
e[i].clear();
}
rep(i, 1, n) cin >> t[i];
rep(i, 1, n - 1) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, -1);
cout << f[1] << '\n';
}
main() {
ios::sync_with_stdio(false);
cin.tie(0);
int t;
cin >> t;
while(t --) {
solve();
}
return 0;
}