AcWing 352. java同学写的题解
原题链接
困难
作者:
季之秋
,
2021-05-20 19:46:47
,
所有人可见
,
阅读 256
/*
(1)每一条非树边都构成一个环,砍掉环中任一条树边之后需要再砍掉一条非树边,
所以环中每条边d[i] ++, d[i]表示砍掉这条边之后需要砍掉几条非树边
(2)树上求差分: d[a]+c d[b]+c, a和b的最短公共祖先 d[p] - 2*c 。 每个点的值就是所有子节点的总和
*/
import java.util.*;
public class Main{
static int N = 100010, M = 300010;
static int n , m, res;
static int e[] = new int[M], ne[] = new int[M], h[] = new int[N], idx;
static int d[] = new int[N];
static int depth[] = new int[N], f[][] = new int[N][18]; // 求lca
public static void main(String[] args){
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
Arrays.fill(h, -1);
for(int i = 0; i < n-1; i ++){
int a = sc.nextInt();
int b = sc.nextInt();
add(a, b); add(b, a);
}
bfs(); // 初始化 倍增lca
for(int i = 0 ;i < m; i ++){
int a = sc.nextInt();
int b = sc.nextInt();
int p = lca(a, b);
d[a] ++; d[b] ++ ; d[p] -= 2; // 差分->需要砍多少条非树边
}
dfs(1, -1); // 前缀和
System.out.println(res);
}
static int dfs(int u, int fa){
int ans = d[u]; // 返回自己节点的值
for(int i = h[u]; i != -1; i = ne[i]){
int j = e[i];
if(j != fa){
int s = dfs(j, u);
if(s == 0) res += m;
else if(s == 1) res += 1;
ans += s;
}
}
return ans;
}
static int lca(int a, int b){
if(depth[a] < depth[b]){
int c = a; a = b; b = c;
}
for(int i = 16; i >= 0; i --){
if(depth[f[a][i]] >= depth[b]){
a = f[a][i];
}
}
if(a == b) return a;
for(int i = 16; i >= 0; i --){
if(f[a][i] != f[b][i]){
a = f[a][i];
b = f[b][i];
}
}
return f[a][0];
}
static void bfs(){
Queue<Integer> q =new LinkedList();
q.add(1);
Arrays.fill(depth, 0x3f3f3f3f);
depth[0] = 0; depth[1] = 1;
while(!q.isEmpty()){
int t = q.poll();
for(int i = h[t]; i != -1; i = ne[i]){
int j = e[i];
if(depth[j] > depth[t] + 1){
depth[j] = depth[t] + 1;
q.add(j);
f[j][0] = t;
for(int k = 1; k <= 16; k ++){
f[j][k] = f[f[j][k-1]][k-1];
}
}
}
}
}
static void add(int a, int b){
e[idx] = b; ne[idx] = h[a]; h[a] = idx ++;
}
}