题目大意
给定一棵无根树,函数$S(x,y)=\sum\limits_{v\in V}(W(v)*min(d(x,v),d(y,v)))$,其中$V$是点集,$W(v)$是$v$点的权值,$d(x,y)$表示树上两点之间的距离。
求当$x$和$y$取多少时,$S(x,y)$能取到最小值,并输出这个最小值。
Solution
首先需要统一的是,$W(v)$是一定的,所以如果只有$1$个点时,就必定取重心。
然而题目中是$2$个点,但是$S(x,y)$的含义还是和重心所差无几。一旦选中了$x,y$,那么整棵树的点集就分成了到$x$点和到$y$点两部分,而这两部分一定是内部互相连通的,所以开始的时候不妨把它看做是切成$2$棵子树,然后对于每棵子树都求一下重心,作为$x,y$。
读题后发现,树的高度最大是$100$,利用这一特性,如果我们可以在$O(h)$的范围内求解任意子树的重心,那么这题就得解了。我们可以枚举断边,然后对于每次断边,分别求出两棵子树的重心,然后更新答案。
-
那么可以在$O(h)$内求解重心吗?
-
可以的!
首先设$f(u)$表示整棵树中所有节点到$u$的距离和(当然要乘上点权),$siz[u]$表示$u$的子树大小,$v$为$u$的一个儿子,则可以递推:
$$f(v)=f(u)+siz[root]-siz[v]-siz[v]$$
解释:当点下移了一格后,除了$v$的子树内的点,其它都多走了一格,所以要加上$siz[root]-siz[v]$,但是$v$为根的子树内所有点可以少走一格,所以减去$siz[v]$。
那么根据重心的性质,如果$v$比$u$更适合当重心,那么有$f(v)<f(u)$
即
$$f(u)+siz[root]-2*siz[v]<f(u)$$
可知,当
$$siz[root]<2*siz[v]$$
时,才有向下走的必要。
又由于,重心中最大的子树不会超过总节点数的一半,所以满足$siz[root]<2*siz[v]$的节点数不会超过$1$个,所以从根开始,每次走子树最大的儿子,直到不满足不等式,之前$f(u)$最大的那个点就是重心。
综上,可以得出这题的解法:
- $stp1\quad$ 先$dfs$预处理出以下信息:
子树大小$siz[u]$,深度$dep[u]$,子树中所有点到当前点的距离和$dp[u]$,每个节点的父亲$fa[u]$。
每个节点所有儿子中子树最大的编号$id1[u]$与次大的编号$id2[u]$。(为什么要次大,后面会讲)。 - $stp2\quad$ 然后枚举切断的边,记断边相连的深度大的那个点是$cut$,然后从$cut$开始,向父亲迭代,并更新父亲的$siz[u]$的值。
- $stp3\quad$ 然后计算答案,可以用递归记录信息。
int ret=inf,ct;
void dfs1(int x,int cur,int S){//x:当前节点 cur:当前答案 S:切断后x所在树的总的大小
ret=min(ret,cur);
int s=id1[x];
if(s==ct||siz[s]<siz[id2[x]]) s=id2[x];//如果最大的儿子被破坏或者被切断,则退取次大的儿子
if(!s) return;
if(S<2*siz[s])
dfs1(s,cur+S-2*siz[s],S);//只有满足条件才有必要递归,并利用上面的递推式得出答案
}
ret=inf;dfs1(1,dp[1]-dp[i]-siz[i]*dep[i],siz[1]);A=ret;
//处理上方的那棵树,此时f[1]=dp[1]-dp[i]-siz[i]*dep[i],不理解的先看着
ret=inf;dfs1(i,dp[i],siz[i]);B=ret;
//处理下方的
- $stp4\quad$ 更新历史最优值,再次从切点开始向父亲迭代,回溯,调到$stp2$
- $stp5\quad$ 输出答案,结束。
对于式子$f[1]=dp[1]-dp[i]-siz[i]*dep[i]$,就是如果只有$f[1]=dp[1]-dp[i]$容易理解少考虑了子树$i$的深度,那么全体乘上深度即可。
Code
#include<bits/stdc++.h>
#define ll long long
#define inf 1<<30
using namespace std;
const int MAXN=5e4+10;
vector<int> e[MAXN];
int w[MAXN];
int dp[MAXN],f[MAXN],dep[MAXN];
int siz[MAXN],id1[MAXN],id2[MAXN];
void dfs(int x,int fa){
f[x]=fa;
siz[x]=w[x];int s1=0,s2=0;
for(int i=0;i<e[x].size();i++){
int s=e[x][i];
if(s==fa) continue;
dep[s]=dep[x]+1;
dfs(s,x);
siz[x]+=siz[s];
dp[x]+=dp[s]+siz[s];
if(siz[s]>s1){
s2=s1,s1=siz[s];
id2[x]=id1[x],id1[x]=s;
}
else if(siz[s]>s2)
s2=siz[s],id2[x]=s;
}
}
int ret=inf,ct;
void dfs1(int x,int cur,int S){
ret=min(ret,cur);
int s=id1[x];
if(s==ct||siz[s]<siz[id2[x]]) s=id2[x];
if(!s) return;
if(S<2*siz[s])
dfs1(s,cur+S-2*siz[s],S);
}
int main()
{
int n,x,y;
scanf("%d",&n);
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
for(int i=1;i<=n;i++)
scanf("%d",&w[i]);
dfs(1,-1);
int ans=inf;
for(int i=2;i<=n;i++){
for(int j=f[i];j!=-1;j=f[j]) siz[j]-=siz[i];
int A,B;ct=i;
ret=inf;dfs1(1,dp[1]-dp[i]-siz[i]*dep[i],siz[1]);A=ret;
ret=inf;dfs1(i,dp[i],siz[i]);B=ret;
ans=min(ans,A+B);
for(int j=f[i];j!=-1;j=f[j]) siz[j]+=siz[i];
}printf("%d\n",ans);
}