对于这道题目,首先应该判断出来每个点的取值应该是取值范围的极值,下面来简单证明一下:假设点$u$的取值是$d$且非极值,那么对于它的任意一个临点$v$,记$v$的取值范围是$[l,r]$,那么$d$关于$[l,r]$的位置关系应该有三种:
- 当$d$位于$[l,r]$的左边,那么此时的最优值应该是$l-d$,但是我们知道$d$可以取到的更小的值使得最终的值更优;
- 当$d$位于$[l,r]$的右边时同理;
- $d$位于$[l,r]$内部时,$d$可以离$l$更近,此时最优值应该是$r-d$,同样我们可以让$d$取到更小值使最终的值更优;同样$d$离$r$更近同理。
那么问题就转换为,给定一棵树,树上的每个点有两个取值,每棵树的值为所有临点取值之差的和。如果直接暴力枚举每个点的取值,那么时间复杂度应该是$O(2^{n}$),显然不合理,看下能否优化。
不难发现,每次每个点没必要从前面所有的状态转移而来,只要择优从更大的值转移而来即可(最优化剪枝)。
- 状态定义:$f[u][0]$表示以节点$u$为根节点的树且$u$选左端点的最大值,$f[u][1]$表示选右端点。
- 状态划分:对于节点$u$,以及它的子节点$v$,当$u$选择任意一个端点时,$v$都有左右两个端点可以选择,只要从转移后更大的那个点转移而来即可。
- 状态转移:$f[u][0]=max(f[v][1]+abs(l[u]-r[v],f[v][0]+abs(l[u]-l[v])$,$f[u][1]$同理。
#include <iostream>
#include <cstring>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
#include <cmath>
#include <unordered_map>
#define bug1(g) cout<<"test: "<<g<<endl
#define bug2(g , i) cout<<"test: "<<g<<" "<<i<<endl
#define bug3(g , i , k) cout<<"test: "<<g<<" "<<i<<" "<<k<<endl
#define bug4(a , g , i , k) cout<<"test: "<<a<<" "<<g<<" "<<i<<" "<<k<<endl
#define INF 0x3f3f3f3f
#define fi first
#define se second
#define met(a , b) memset(a , b , sizeof a);
#define pb push_back
using namespace std;
typedef long long LL ;
typedef pair<LL , LL> PII;
const int N = 100010 , M = 1000010;
int e[N * 2] , ne[N * 2] , h[N] , idx;
int n;
PII t[N];
LL f[M][2];
LL v[N];
void add(int a, int b)
{
e[idx] = b , ne[idx] = h[a] , h[a] = idx++;
}
void dfs(int u , int fa)
{
for(int i = h[u] ; ~i ; i = ne[i])
{
int j = e[i];
if(j == fa) continue;
dfs(j , u);
f[u][1] += max(f[j][1] + abs(t[u].se - t[j].se) , f[j][0] + abs(t[u].se - t[j].fi));
f[u][0] += max(f[j][1] + abs(t[u].fi - t[j].se) , f[j][0] + abs(t[u].fi - t[j].fi));
//cout << f[u][1] << ' ' << f[u][0] << endl;
}
}
void solve()
{
memset(h , -1 , sizeof h);
met(f , 0);
met(v , 0);
idx = 0;
cin >> n;
for(int i = 1 ; i <= n ; i++)
{
int l , r;
scanf("%d%d" , &l , &r);
t[i] = {l , r};
}
for(int i = 1 ; i < n ; i++)
{
int a , b;
cin >> a >> b;
add(a , b) , add(b , a);
}
dfs(1 , 1);//以任意一个点为根节点,记录一个父节点,避免往回递归
printf("%lld\n" , max(f[1][0] , f[1][1]));
}
int main()
{
int T = 1;
cin >> T;
for(int turn = 1 ; turn <= T ; turn++)
solve();
}