算法分析
找树的中心类似于找圆的半径
-
d1[u]
:表示u
点向下走的最大长度 -
d2[u]
:表示u
点向下走的次大长度 -
up[u]
:表示u
点向上走的最大长度 -
假设当前点是
j
点,上一结点是u
点,计算经过j
点的到所有点的最长距离res = max(d1[j],up[j])
-
当
u
点往下走的最大长度经过j
,则需要用到次大长度
up[j] = Math.max(up[u], d2[u]) + w[i]
-
当
u
点往下走的最大长度不经过j
,则
up[j] = Math.max(up[u], d1[u]) + w[i]
-
时间复杂度 $O(n)$
参考文献
算法提高课
Java 代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Scanner;
public class Main {
static int N = 10010;
static int M = N * 2;
static int INF = 0x3f3f3f3f;
static int[] h = new int[N];
static int[] e = new int[M];
static int[] ne = new int[M];
static int[] w = new int[M];
static int idx = 0;
static int[] d1 = new int[N];
static int[] d2 = new int[N];
static int[] up = new int[N];
static int[] son1 = new int[N];//son1[i] = j表示 i的大儿子是j
static int[] son2 = new int[N];//son2[i] = j表示 i的二儿子是j
static void add(int a,int b,int c)
{
e[idx] = b;
w[idx] = c;
ne[idx] = h[a];
h[a] = idx ++;
}
//找到u点往下走的最大长度
static int dfs_down(int u,int father)
{
d1[u] = -INF;
d2[u] = -INF;
for(int i = h[u];i != -1;i = ne[i])
{
int j = e[i];
if(j == father) continue;
int d = dfs_down(j,u) + w[i];
if(d > d1[u])
{
d2[u] = d1[u]; d1[u] = d;
son2[u] = son1[u]; son1[u] = j;
}
else if(d > d2[u])
{
d2[u] = d;
son2[u] = j;
}
}
if(d1[u] == -INF) {d1[u] = 0; d2[u] = 0;}
return d1[u];
}
static void dfs_up(int u,int father)
{
for(int i = h[u];i != -1;i = ne[i])
{
int j = e[i];
if(j == father) continue;
if(son1[u] == j) up[j] = Math.max(up[u], d2[u]) + w[i];
else up[j] = Math.max(up[u], d1[u]) + w[i];
dfs_up(j,u);
}
}
public static void main(String[] args) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(reader.readLine().trim());
Arrays.fill(h,-1);
for(int i = 0;i < n - 1;i ++)
{
String[] s1 = reader.readLine().split(" ");
int a = Integer.parseInt(s1[0]);
int b = Integer.parseInt(s1[1]);
int c = Integer.parseInt(s1[2]);
add(a,b,c);
add(b,a,c);
}
dfs_down(1,-1);
dfs_up(1,-1);
int res = INF;
for(int i = 1;i <= n;i ++) res = Math.min(res, Math.max(d1[i], up[i]));
System.out.println(res);
}
}