lca模板
P3379
import java.io.*;
import java.util.*;
/**
 * Lowest-common-ancestor template (binary lifting).
 *
 * Input read by {@link #main}: n, m, the root id, n-1 undirected edges, then m query pairs.
 * Output: the LCA of each query pair, one per line.
 * Index 0 acts as a sentinel "above the root": its depth is 0 and every ancestor
 * slot that would leave the tree stays 0.
 */
public class Main {
    static int N = 500010, M = 2 * N;
    // Adjacency lists stored in arrays: h[u] = first edge index of u, e[i] = target, ne[i] = next edge (-1 = end).
    static int[] h = new int[N], e = new int[M], ne = new int[M];
    static int[] depth = new int[N];
    // fa[v][k] = the 2^k-th ancestor of v (0 when it does not exist).
    static int[][] fa = new int[N][20];
    static int idx = 0, INF = 0x3f3f3f3f, n, m;
    static BufferedReader br
        = new BufferedReader(new InputStreamReader(System.in));
    static StreamTokenizer sc = new StreamTokenizer(br);
    static BufferedWriter bw
        = new BufferedWriter(new OutputStreamWriter(System.out));

    /** Reads the tree and the queries from standard input and prints one LCA per query. */
    public static void main(String[] args) throws IOException {
        n = nextRead();
        m = nextRead();
        Arrays.fill(h, -1);
        int root = nextRead();
        for (int edge = 1; edge < n; edge++) {
            int u = nextRead();
            int v = nextRead();
            add(u, v);
            add(v, u);
        }
        bfs(root);
        for (int query = 0; query < m; query++) {
            int u = nextRead();
            int v = nextRead();
            bw.write(lca(u, v) + "\n");
        }
        bw.flush();
        bw.close();
        br.close();
    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}.
     * Requires {@link #bfs} to have filled {@code depth} and {@code fa}.
     */
    public static int lca(int a, int b) {
        int low = a;   // the deeper of the two nodes
        int high = b;
        if (depth[low] < depth[high]) {
            low = b;
            high = a;
        }
        // Lift the deeper node until both are on the same level.
        for (int k = 19; k >= 0; k--) {
            int up = fa[low][k];
            if (depth[up] >= depth[high]) {
                low = up;
            }
        }
        if (low == high) {
            return low;
        }
        // Lift both as long as their ancestors differ; they end right below the LCA.
        int k = 19;
        while (k >= 0) {
            int upLow = fa[low][k];
            int upHigh = fa[high][k];
            if (upLow != upHigh) {
                low = upLow;
                high = upHigh;
            }
            k--;
        }
        return fa[low][0];
    }

    /** Breadth-first pass from {@code root} that fills {@code depth} (root = 1) and the ancestor table. */
    public static void bfs(int root) {
        Arrays.fill(depth, INF);
        depth[0] = 0;
        depth[root] = 1;
        Deque<Integer> pending = new ArrayDeque<>();
        pending.addLast(root);
        while (!pending.isEmpty()) {
            int cur = pending.pollFirst();
            int childDepth = depth[cur] + 1;
            for (int i = h[cur]; i != -1; i = ne[i]) {
                int next = e[i];
                if (depth[next] <= childDepth) {
                    continue; // already reached at this depth or shallower
                }
                depth[next] = childDepth;
                pending.addLast(next);
                fa[next][0] = cur;
                for (int k = 1; k <= 19; k++) {
                    fa[next][k] = fa[fa[next][k - 1]][k - 1];
                }
            }
        }
    }

    /** Appends the directed edge a -> b to the front of a's adjacency list. */
    public static void add(int a, int b) {
        int slot = idx++;
        e[slot] = b;
        ne[slot] = h[a];
        h[a] = slot;
    }

    /** Reads the next token from standard input and returns it truncated to int. */
    public static int nextRead() throws IOException {
        sc.nextToken();
        return (int) sc.nval;
    }
}
P5836 [USACO19DEC] Milk Visits S
分析:
- 题目要求判断两点间的路径上是否存在某种牛(即某个给定的值)
- 通过lca可以快速求出两点的最近公共祖先
- 假设查询的两点为x,y,由于最多只有两种牛,可以使用一个二维数组sum[i][2]
记录从根节点到i的这条路径上(按"点"统计)每种牛各有多少头,可以在预处理depth和fa的同时预处理出sum
假定要判断的是A这种牛,那么有cnt = sum[x][A]+sum[y][A]-sum[lca(x,y)][A]-sum[fa[lca(x,y)][0]][A],cnt>0即表示路径上存在A
import java.io.*;
import java.util.*;
/**
 * Path query "does the path x..y contain a cow of type C" via LCA plus root-to-node counts.
 *
 * Input read by {@link #main}: n and m, a string of n letters ('G' or 'H', one per node),
 * n-1 undirected edges, then m queries "a b C". Output: one character per query,
 * '1' when the path contains type C and '0' otherwise, with no separators.
 * Index 0 is a sentinel above the root (depth 0, all counts 0).
 */
public class Main{
    static int N = 100010,M = 2*N;
    // Adjacency lists stored in arrays: h[u] = first edge index of u, e[i] = target, ne[i] = next edge (-1 = end).
    static int[] h = new int[N],e = new int[M],ne = new int[M];
    // val[v] = letter of node v minus 'G' (0 for 'G', 1 for 'H').
    static int[] depth = new int[N],val = new int[N];
    // fa[v][k] = 2^k-th ancestor of v; sum[v][c] = number of nodes of type c on the path root..v.
    static int[][] fa = new int[N][20],sum = new int[N][2];
    static int idx = 0,INF = 0x3f3f3f3f,n,m;
    static BufferedReader br
        = new BufferedReader(new InputStreamReader(System.in));
    static StreamTokenizer sc = new StreamTokenizer(br);
    static BufferedWriter bw
        = new BufferedWriter(new OutputStreamWriter(System.out));

    /**
     * Reads the whole input line by line and answers every query.
     *
     * @throws IOException if reading fails or the input ends early
     */
    public static void main(String[] args)throws IOException{
        String[] s = readTokens();
        n = Integer.parseInt(s[0]);
        m = Integer.parseInt(s[1]);
        Arrays.fill(h,-1);
        String ss = readTokens()[0];
        for(int i = 1;i<=n;i++)
            val[i] = ss.charAt(i-1)-'G';
        for(int i = 1;i<=n-1;i++){
            s = readTokens();
            int a = Integer.parseInt(s[0]);
            int b = Integer.parseInt(s[1]);
            add(a,b);
            add(b,a);
        }
        bfs();
        for(int i = 1;i<=m;i++){
            s = readTokens();
            int a = Integer.parseInt(s[0]);
            int b = Integer.parseInt(s[1]);
            int c = s[2].charAt(0)-'G';
            int anc = lca(a,b);
            // Nodes of type c on the path a..b: both root paths minus the part above the LCA
            // (the LCA itself is kept once, hence subtracting the LCA's parent for the second copy).
            int ans = sum[a][c]+sum[b][c]-sum[anc][c]-sum[fa[anc][0]][c];
            bw.write(ans>0 ? "1" : "0");
        }
        bw.flush();
        bw.close();
        br.close();
    }

    /**
     * Reads the next non-blank line and splits it on runs of whitespace.
     * Leading, trailing and repeated blanks are tolerated, unlike a plain split(" ").
     */
    private static String[] readTokens()throws IOException{
        String line = br.readLine();
        while(line != null && line.trim().isEmpty())
            line = br.readLine();
        if(line == null)
            throw new EOFException("unexpected end of input");
        return line.trim().split("\\s+");
    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}.
     * Requires {@link #bfs} to have filled {@code depth} and {@code fa}.
     */
    public static int lca(int a,int b){
        if(depth[a]<depth[b]){
            int c = a;
            a = b;
            b = c;
        }
        // Lift the deeper node to the level of the other one.
        for(int k = 19;k>=0;k--)
            if(depth[fa[a][k]]>=depth[b])
                a = fa[a][k];
        if(a == b) return a;
        // Lift both while their ancestors differ; they stop right below the LCA.
        for(int k = 19;k>=0;k--)
            if(fa[a][k]!=fa[b][k]){
                a = fa[a][k];
                b = fa[b][k];
            }
        return fa[a][0];
    }

    /** Breadth-first pass from node 1 filling depth, the ancestor table and the per-type prefix counts. */
    public static void bfs(){
        Arrays.fill(depth,INF);
        depth[0] = 0;
        depth[1] = 1;
        Deque<Integer> q = new ArrayDeque<>();
        q.add(1);
        sum[1][val[1]]++;
        while(!q.isEmpty()){
            int t = q.poll();
            for(int i = h[t];i!=-1;i = ne[i]){
                int j = e[i];
                if(depth[j]>depth[t]+1){
                    depth[j] = depth[t]+1;
                    q.add(j);
                    fa[j][0] = t;
                    // Counts for j = counts of its parent plus j's own type.
                    sum[j][0] = sum[t][0];
                    sum[j][1] = sum[t][1];
                    sum[j][val[j]]++;
                    for(int k = 1;k<=19;k++)
                        fa[j][k] = fa[fa[j][k-1]][k-1];
                }
            }
        }
    }

    /** Appends the directed edge a -> b to the front of a's adjacency list. */
    public static void add(int a,int b){
        e[idx] = b;
        ne[idx] = h[a];
        h[a] = idx++;
    }

    /** Reads the next token through the StreamTokenizer and returns it truncated to int (not used by main). */
    public static int nextRead()throws IOException{
        sc.nextToken();
        return (int)sc.nval;
    }
}
分析:lca+树上差分
“点”树上差分代码:
anc = lca(a,b);
sum[a]++;
sum[b]++;
sum[anc]--;
sum[fa[anc][0]]--;
松鼠会从序列的第一个点走到第二个点,再从第二个点走到第三个点,依次类推。每段路径都会对两个端点各加一次,
所以除了序列的第一个点以外,其余每个点既作为"上一段的终点"又作为"下一段的起点"而被多算一次;对于最后一个点,因为是餐厅不需要糖果,所以同样要减一,
因此在用dfs求完子树和之后,要对seq[2..n]中的每个点都减去1
import java.io.*;
import java.util.*;
/**
 * Counts how many times each node is visited when walking seq[1] -> seq[2] -> ... -> seq[n]
 * along tree paths, using LCA plus difference-on-nodes.
 *
 * Input read by {@link #main}: n, the visiting order seq[1..n], then n-1 undirected edges.
 * Output: one count per node, one per line. Index 0 is a sentinel above the root.
 */
public class Main{
    static int N = 300010,M = 700010;
    // Adjacency lists stored in arrays: h[u] = first edge index of u, e[i] = target, ne[i] = next edge (-1 = end).
    static int[] h = new int[N],e = new int[M],ne = new int[M];
    // val holds the node differences first and, after dfs, the subtree sums (the per-node answers).
    static int[] depth = new int[N],val = new int[N],
        ans = new int[N];
    // fa[v][k] = 2^k-th ancestor of v (0 when it does not exist).
    static int[][] fa = new int[N][20];
    static int n,m,INF = 0x3f3f3f3f,idx = 0;
    static BufferedReader br
        = new BufferedReader(new InputStreamReader(System.in));
    static BufferedWriter bw
        = new BufferedWriter(new OutputStreamWriter(System.out));
    static StreamTokenizer sc = new StreamTokenizer(br);
    static int[] seq = new int[N];

    /** Reads the input, applies one path update per consecutive pair of seq and prints the counts. */
    public static void main(String[] args)throws IOException{
        n = nextRead();
        Arrays.fill(h,-1);
        for(int i = 1;i<=n;i++)
            seq[i] = nextRead();
        for(int i = 0;i<n-1;i++){
            int a = nextRead();
            int b = nextRead();
            add(a,b);
            add(b,a);
        }
        bfs();
        for(int i = 1;i<n;i++){
            int a = seq[i],b = seq[i+1];
            int anc = lca(a,b);
            // Node difference for the path a..b: +1 at both ends, -1 at the LCA and at its parent.
            val[a]++;
            val[b]++;
            val[anc]--;
            val[fa[anc][0]]--;
        }
        dfs(1,0);
        // Every node except seq[1] was counted both as the end of one path and the start of the next.
        for(int i = 2;i<=n;i++)
            val[seq[i]]--;
        for(int i = 1;i<=n;i++)
            bw.write(val[i]+"\n");
        bw.flush();
        bw.close();
        br.close();
    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}.
     * Requires {@link #bfs} to have filled {@code depth} and the ancestor table.
     */
    public static int lca(int a,int b){
        if(depth[a]<depth[b]){
            int c = a;
            a = b;
            b = c;
        }
        // Lift the deeper node to the level of the other one.
        for(int k = 19;k>=0;k--)
            if(depth[fa[a][k]]>=depth[b])
                a = fa[a][k];
        if(a == b) return a;
        // Lift both while their ancestors differ; they stop right below the LCA.
        for(int k = 19;k>=0;k--)
            if(fa[a][k]!=fa[b][k]){
                a = fa[a][k];
                b = fa[b][k];
            }
        return fa[a][0];
    }

    /** Breadth-first pass from node 1 filling depth (root = 1) and the ancestor table. */
    public static void bfs(){
        Arrays.fill(depth,INF);
        depth[0] = 0;
        depth[1] = 1;
        Deque<Integer> q = new ArrayDeque<>();
        q.add(1);
        while(!q.isEmpty()){
            int t = q.poll();
            for(int i = h[t];i!=-1;i = ne[i]){
                int j = e[i];
                if(depth[j]>depth[t]+1){
                    depth[j] = depth[t]+1;
                    q.add(j);
                    fa[j][0] = t;
                    for(int k = 1;k<=19;k++)
                        fa[j][k] = fa[fa[j][k-1]][k-1];
                }
            }
        }
    }

    /**
     * Turns the differences in {@code val} into subtree sums for the subtree rooted at {@code u}.
     * After the call, val[v] is the sum of the original val over v's subtree; val[fa] is untouched.
     *
     * Implemented without recursion: a recursive walk needs one stack frame per tree level
     * and overflows the default thread stack on chain-shaped trees of a few hundred thousand nodes.
     * The edges reachable from u must form a tree.
     *
     * @param u  root of the subtree to process
     * @param fa parent of u (0 for the tree root); never descended into
     */
    public static void dfs(int u,int fa){
        // Each node other than u is discovered through a distinct edge slot, so idx+1 entries suffice.
        int[] order = new int[idx+1];
        int[] parent = new int[h.length];
        int head = 0,tail = 0;
        order[tail++] = u;
        parent[u] = fa;
        // Level-order walk: every parent is recorded before any of its children.
        while(head<tail){
            int t = order[head++];
            for(int i = h[t];i!=-1;i = ne[i]){
                int j = e[i];
                if(j == parent[t]) continue;
                parent[j] = t;
                order[tail++] = j;
            }
        }
        // Reverse order visits children before parents, so each val[j] is final when it is pushed up.
        for(int i = tail-1;i>=1;i--){
            int j = order[i];
            val[parent[j]] += val[j];
        }
    }

    /** Reads the next token from standard input and returns it truncated to int. */
    public static int nextRead()throws IOException{
        sc.nextToken();
        return (int)sc.nval;
    }

    /** Appends the directed edge a -> b to the front of a's adjacency list. */
    public static void add(int a,int b){
        e[idx] = b;
        ne[idx] = h[a];
        h[a] = idx++;
    }
}
P6869 [COCI2019-2020#5] Putovanje
分析:lca+树上差分
“边”树上差分代码:
anc = lca(a,b);
sum[a]++;
sum[b]++;
sum[anc]-=2;
做完lca之后再做一次dfs计算前缀和,对于每条边的贡献值val = min(sum[i]*ci1,ci2);
由于java洛谷直接dfs会爆栈,所以下面使用的是手写栈的写法
import java.io.*;
import java.util.*;
/**
 * Minimum total ticket cost for travelling 1 -> 2 -> ... -> n on a tree, where every edge
 * offers a per-trip price and a flat price: LCA plus difference-on-edges.
 *
 * Input read by {@link #main}: n, then n-1 lines "a b c1 c2". Output: the total cost.
 * Index 0 is a sentinel above the root.
 */
public class Main {
    static int N = 200010, M = 2 * N;
    // Adjacency lists in arrays; c1[i]/c2[i] are the two prices of edge slot i,
    // faEdge[v] is the edge slot through which bfs reached v.
    static int[] h = new int[N], e = new int[M], ne = new int[M],
        c1 = new int[M], c2 = new int[M], faEdge = new int[N];
    // sum[v]: difference value first, later the number of trips over the edge above v.
    static int[] depth = new int[N], sum = new int[N];
    static int[][] fa = new int[N][20];
    static BufferedReader br
        = new BufferedReader(new InputStreamReader(System.in));
    static StreamTokenizer sc = new StreamTokenizer(br);
    static int n, m, idx = 0, INF = 0x3f3f3f3f;
    static long ans = 0;

    /** Reads the tree, marks every path i -> i+1 and prints the accumulated cost. */
    public static void main(String[] args) throws IOException {
        n = nextRead();
        Arrays.fill(h, -1);
        for (int road = 1; road < n; road++) {
            int u = nextRead();
            int v = nextRead();
            int perTrip = nextRead();
            int flat = nextRead();
            add(u, v, perTrip, flat);
            add(v, u, perTrip, flat);
        }
        bfs();
        for (int city = 1; city < n; city++) {
            int next = city + 1;
            int top = lca(city, next);
            // Edge difference for the path city..next: +1 at both ends, -2 at the LCA.
            sum[city] += 1;
            sum[next] += 1;
            sum[top] -= 2;
        }
        dfs(1);
        System.out.println(ans);
    }

    /**
     * Accumulates subtree sums bottom-up starting from {@code u} and adds each node's
     * parent-edge cost min(trips * c1, c2) to {@code ans}. Uses an explicit stack instead of recursion.
     */
    public static void dfs(int u) {
        int[] pending = new int[N];
        int[] order = new int[N];
        boolean[] seen = new boolean[N];
        int pendingSize = 0;
        int orderSize = 0;
        seen[u] = true;
        pending[pendingSize++] = u;
        // First pass: record a visiting order in which every parent precedes its children.
        while (pendingSize > 0) {
            int node = pending[--pendingSize];
            order[orderSize++] = node;
            for (int i = h[node]; i != -1; i = ne[i]) {
                int next = e[i];
                if (seen[next]) {
                    continue;
                }
                seen[next] = true;
                pending[pendingSize++] = next;
            }
        }
        // Second pass: walk the order backwards so children are settled before their parent.
        for (int pos = orderSize - 1; pos >= 0; pos--) {
            int node = order[pos];
            int up = faEdge[node];
            long payPerTrip = (long) sum[node] * c1[up];
            long payFlat = c2[up];
            ans += Math.min(payPerTrip, payFlat);
            sum[fa[node][0]] += sum[node];
        }
    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}.
     * Requires {@link #bfs} to have filled {@code depth} and {@code fa}.
     */
    public static int lca(int a, int b) {
        int low = a;   // the deeper of the two nodes
        int high = b;
        if (depth[low] < depth[high]) {
            low = b;
            high = a;
        }
        for (int k = 19; k >= 0; k--) {
            int up = fa[low][k];
            if (depth[up] >= depth[high]) {
                low = up;
            }
        }
        if (low == high) {
            return low;
        }
        for (int k = 19; k >= 0; k--) {
            int upLow = fa[low][k];
            int upHigh = fa[high][k];
            if (upLow != upHigh) {
                low = upLow;
                high = upHigh;
            }
        }
        return fa[low][0];
    }

    /** Breadth-first pass from node 1 filling depth, the ancestor table and faEdge. */
    public static void bfs() {
        Arrays.fill(depth, INF);
        depth[0] = 0;
        depth[1] = 1;
        Deque<Integer> pending = new ArrayDeque<>();
        pending.addLast(1);
        while (!pending.isEmpty()) {
            int cur = pending.pollFirst();
            int childDepth = depth[cur] + 1;
            for (int i = h[cur]; i != -1; i = ne[i]) {
                int next = e[i];
                if (depth[next] <= childDepth) {
                    continue;
                }
                depth[next] = childDepth;
                pending.addLast(next);
                fa[next][0] = cur;
                faEdge[next] = i;
                for (int k = 1; k <= 19; k++) {
                    fa[next][k] = fa[fa[next][k - 1]][k - 1];
                }
            }
        }
    }

    /** Appends the directed edge a -> b with prices c (per trip) and d (flat) to a's adjacency list. */
    public static void add(int a, int b, int c, int d) {
        int slot = idx++;
        e[slot] = b;
        ne[slot] = h[a];
        c1[slot] = c;
        c2[slot] = d;
        h[a] = slot;
    }

    /** Reads the next token from standard input and returns it truncated to int. */
    public static int nextRead() throws IOException {
        sc.nextToken();
        return (int) sc.nval;
    }
}
同一份代码Java被卡了,还是用远方的C++吧家人们
模板题,lca+哈希表离散化
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <queue>
using namespace std;
// Capacities: N bounds the number of nodes, M the number of directed edge slots
// (each undirected edge is stored twice).
const int N = 10010,M = 2*N;
// Adjacency lists stored in arrays: h[u] = first edge slot of u, e[i] = target of slot i,
// ne[i] = next slot in the same list (-1 terminates a list).
int h[N],e[M],ne[M];
// depth[v] = depth assigned by bfs (root = 1, sentinel 0 = 0); seq[i] = i-th key of the input sequence.
int depth[N],seq[N];
// fa[v][k] = 2^k-th ancestor of v (0 when it does not exist).
int fa[N][14];
// idx = next free edge slot; n = number of keys, m = number of queries.
int idx = 0,INF = 0x3f3f3f3f,n,m;
// Maps a key to its position in seq; positions serve as node ids.
unordered_map<int,int> map;
// Appends the directed edge a -> b to the front of a's adjacency list.
void add(int a,int b){
    const int slot = idx++;
    e[slot] = b;
    ne[slot] = h[a];
    h[a] = slot;
}
int lca(int a,int b){
if(depth[a]<depth[b]){
int c = a;
a = b;
b = c;
}
for(int k = 13;k>=0;k--)
if(depth[fa[a][k]]>=depth[b])
a = fa[a][k];
if(a == b) return a;
for(int k = 13;k>=0;k--)
if(fa[a][k]!=fa[b][k]){
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
void bfs(int root){
memset(depth,INF,sizeof depth);
depth[0] = 0;
depth[root] = 1;
queue<int> q;
q.push(root);
while(q.size()>0){
int t = q.front();
q.pop();
for(int i = h[t];i!=-1;i = ne[i]){
int j = e[i];
if(depth[j]>depth[t]+1){
depth[j] = depth[t]+1;
q.push(j);
fa[j][0] = t;
for(int k = 1;k<=13;k++){
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
}
}
// Builds the tree edges for the preorder slice seq[l..r] of a binary search tree.
// Node ids are positions in seq; seq[l] is the root of the slice.
void build(int l,int r){
    if(l>=r) return;
    const int rootKey = seq[l];
    const int first = l+1;
    // A smaller key right after the root is the root's left child.
    if(seq[first]<rootKey){
        add(l,first);
        add(first,l);
    }
    // split = first position whose key is not smaller than the root (start of the right subtree).
    int split = first;
    while(split<=r && seq[split]<rootKey) split++;
    if(first<split-1) build(first,split-1);
    if(split<=r){
        add(l,split);
        add(split,l);
        build(split,r);
    }
}
// Reads m (queries) and n (keys), the preorder key sequence and the queries;
// prints the LCA / ancestor relation or an error for keys that are absent.
int main(){
    scanf("%d%d",&m,&n);
    memset(h,-1,sizeof h);
    for(int pos = 1;pos<=n;pos++){
        int key;
        scanf("%d",&key);
        seq[pos] = key;
        map.insert({key,pos});
    }
    build(1,n);
    bfs(1);
    for(int q = 0;q<m;q++){
        int a,b;
        scanf("%d%d",&a,&b);
        const auto ita = map.find(a);
        const auto itb = map.find(b);
        const bool missA = ita == map.end();
        const bool missB = itb == map.end();
        if(missA && missB){
            printf("ERROR: %d and %d are not found.\n",a,b);
        }else if(missA){
            printf("ERROR: %d is not found.\n",a);
        }else if(missB){
            printf("ERROR: %d is not found.\n",b);
        }else{
            const int anc = seq[lca(ita->second,itb->second)];
            if(anc == a)
                printf("%d is an ancestor of %d.\n",a,b);
            else if(anc == b)
                printf("%d is an ancestor of %d.\n",b,a);
            else
                printf("LCA of %d and %d is %d.\n",a,b,anc);
        }
    }
    return 0;
}
java代码
import java.io.*;
import java.util.*;
/**
 * LCA queries on a binary search tree given by its preorder key sequence.
 *
 * Input read by {@link #main}: m (queries) and n (keys), the n keys in preorder, then m key pairs.
 * Output: per query the LCA, an ancestor relation, or an error naming keys that are absent.
 *
 * Nodes are identified by their position 1..n in the preorder sequence, never by the key itself:
 * keys are arbitrary ints (possibly negative, zero or far larger than the array capacity) and
 * cannot be used as array indices. Index 0 is a sentinel above the root.
 */
public class Main{
    static int N = 10010,M = 2*N;
    // Adjacency lists stored in arrays: h[u] = first edge slot of u, e[i] = target, ne[i] = next slot (-1 = end).
    static int[] h = new int[N],e = new int[M],ne = new int[M];
    // depth[v] = depth of node v (root = 1); seq[i] = key stored at preorder position i.
    static int[] depth = new int[N],seq = new int[N];
    // fa[v][k] = 2^k-th ancestor of node v (0 when it does not exist).
    static int[][] fa = new int[N][14];
    static int idx = 0,INF = 0x3f3f3f3f,n,m;

    /** Reads the input from standard input and prints one line per query. */
    public static void main(String[] args){
        Scanner sc = new Scanner(System.in);
        m = sc.nextInt();
        n = sc.nextInt();
        Arrays.fill(h,-1);
        // key -> preorder position (= node id)
        Map<Integer,Integer> position = new HashMap<>();
        for(int i = 1;i<=n;i++) {
            seq[i] = sc.nextInt();
            position.put(seq[i],i);
        }
        build(1,n);
        bfs(1);
        for(int i = 1;i<=m;i++){
            int a = sc.nextInt();
            int b = sc.nextInt();
            Integer posA = position.get(a);
            Integer posB = position.get(b);
            if(posA == null && posB == null){
                System.out.println("ERROR: "+a+" and "+b+
                    " are not found.");
                continue;
            }
            if(posA == null){
                System.out.println("ERROR: "+a+" is not found.");
                continue;
            }
            if(posB == null){
                System.out.println("ERROR: "+b+" is not found.");
                continue;
            }
            int anc = seq[lca(posA,posB)];
            if(anc == a)
                System.out.println(a+" is an ancestor of "+b+".");
            else if(anc == b)
                System.out.println(b+" is an ancestor of "+a+".");
            else
                System.out.println("LCA of "+a+" and "+b+" is "+anc+".");
        }
    }

    /**
     * Returns the node id (preorder position) of the lowest common ancestor of nodes a and b.
     * Requires {@link #bfs} to have filled {@code depth} and {@code fa}.
     */
    public static int lca(int a,int b){
        if(depth[a]<depth[b]){
            int c = a;
            a = b;
            b = c;
        }
        // Lift the deeper node to the level of the other one.
        for(int k = 13;k>=0;k--)
            if(depth[fa[a][k]]>=depth[b])
                a = fa[a][k];
        if(a == b) return a;
        // Lift both while their ancestors differ; they stop right below the LCA.
        for(int k = 13;k>=0;k--)
            if(fa[a][k]!=fa[b][k]){
                a = fa[a][k];
                b = fa[b][k];
            }
        return fa[a][0];
    }

    /** Breadth-first pass from node {@code root} filling depth (root = 1) and the ancestor table. */
    public static void bfs(int root){
        Arrays.fill(depth,INF);
        depth[0] = 0;
        depth[root] = 1;
        Deque<Integer> q = new ArrayDeque<>();
        q.add(root);
        while(!q.isEmpty()){
            int t = q.poll();
            for(int i = h[t];i!=-1;i = ne[i]){
                int j = e[i];
                if(depth[j]>depth[t]+1){
                    depth[j] = depth[t]+1;
                    q.add(j);
                    fa[j][0] = t;
                    for(int k = 1;k<=13;k++){
                        fa[j][k] = fa[fa[j][k-1]][k-1];
                    }
                }
            }
        }
    }

    /**
     * Adds the tree edges for the preorder slice seq[l..r]; position l is the root of the slice.
     * Edges connect positions, not keys.
     */
    public static void build(int l,int r){
        if(l>=r) return;
        int rootKey = seq[l];
        // A smaller key right after the root is the root's left child.
        if(seq[l+1]<rootKey){
            add(l,l+1);
            add(l+1,l);
        }
        // split = first position whose key is not smaller than the root (start of the right subtree).
        int split = l+1;
        while(split<=r && seq[split]<rootKey) split++;
        if(l+1<split-1) build(l+1,split-1);
        if(split<=r){
            add(l,split);
            add(split,l);
            build(split,r);
        }
    }

    /** Appends the directed edge a -> b to the front of a's adjacency list. */
    public static void add(int a,int b){
        e[idx] = b;
        ne[idx] = h[a];
        h[a] = idx++;
    }
}
tarjan-lca
import java.io.*;
import java.util.*;
/**
 * Offline LCA (Tarjan, union-find) on a binary tree given by its inorder and preorder sequences.
 *
 * Input read by {@link #main}: m (queries) and n (keys), the inorder keys, the preorder keys,
 * then m key pairs. Output: per query the LCA, an ancestor relation, or an error for absent keys.
 * Nodes are identified by their position 1..n in the preorder sequence.
 */
public class Main {
    static int N = 10010, M = 2 * N;
    // Adjacency lists stored in arrays: h[u] = first edge slot of u, e[i] = target, ne[i] = next slot (-1 = end).
    static int[] h = new int[N], e = new int[M], ne = new int[M];
    // st: 0 = unvisited, 1 = being processed, 2 = finished; p: union-find parent;
    // seq: inorder keys; seqx: preorder keys.
    static int[] st = new int[N], p = new int[N],
        seq = new int[N], seqx = new int[N];
    // list[u]: pending queries at node u as {other node, query id, +1 if u is the first key / -1 if the second}.
    static List<int[]>[] list = new List[N];
    static String[] ans = new String[1010];
    static int idx = 0, INF = 0x3f3f3f3f, n, m;
    static BufferedReader br
        = new BufferedReader(new InputStreamReader(System.in));
    static StreamTokenizer sc = new StreamTokenizer(br);

    /** Reads the input, registers the answerable queries, runs the offline pass and prints all answers. */
    public static void main(String[] args) throws IOException {
        BufferedWriter bw
            = new BufferedWriter(new OutputStreamWriter(System.out));
        m = nextRead();
        n = nextRead();
        Arrays.fill(h, -1);
        for (int v = 0; v <= n; v++) {
            p[v] = v;
        }
        HashMap<Integer, Integer> map = new HashMap<>();
        for (int i = 1; i <= n; i++) {
            seq[i] = nextRead();
        }
        for (int i = 1; i <= n; i++) {
            seqx[i] = nextRead();
            map.put(seqx[i], i);
        }
        for (int q = 1; q <= m; q++) {
            int a = nextRead();
            int b = nextRead();
            boolean hasA = map.containsKey(a);
            boolean hasB = map.containsKey(b);
            if (!hasA && !hasB) {
                ans[q] = "ERROR: " + a + " and " + b + " are not found." + "\n";
            } else if (!hasA) {
                ans[q] = missing(a);
            } else if (!hasB) {
                ans[q] = missing(b);
            } else if (a == b) {
                ans[q] = ancestor(a, a);
            } else {
                int posA = map.get(a);
                int posB = map.get(b);
                if (list[posA] == null) {
                    list[posA] = new ArrayList<>();
                }
                if (list[posB] == null) {
                    list[posB] = new ArrayList<>();
                }
                list[posA].add(new int[]{posB, q, 1});
                list[posB].add(new int[]{posA, q, -1});
            }
        }
        build(1, n, 1, n);
        tarjan(1);
        for (int q = 1; q <= m; q++) {
            bw.write(ans[q]);
        }
        bw.flush();
    }

    /** Formats the error line for a single absent key. */
    private static String missing(int key) {
        return "ERROR: " + key + " is not found." + "\n";
    }

    /** Formats the line stating that key {@code top} is an ancestor of key {@code bottom}. */
    private static String ancestor(int top, int bottom) {
        return top + " is an ancestor of " + bottom + "." + "\n";
    }

    /** Reads the next token from standard input and returns it truncated to int. */
    public static int nextRead() throws IOException {
        sc.nextToken();
        return (int) sc.nval;
    }

    /**
     * Depth-first pass of Tarjan's offline LCA from node {@code u}: finishes all children,
     * merges each into u, then answers every pending query whose other node is already finished.
     */
    public static void tarjan(int u) {
        st[u] = 1;
        for (int i = h[u]; i != -1; i = ne[i]) {
            int child = e[i];
            if (st[child] != 0) {
                continue;
            }
            tarjan(child);
            p[child] = u;
        }
        List<int[]> queries = list[u];
        if (queries != null) {
            for (int[] query : queries) {
                int other = query[0], id = query[1], state = query[2];
                if (st[other] != 2) {
                    continue; // answered later, when the other node's turn comes
                }
                int anc = find(other);
                // Restore the order in which the two keys appeared in the query.
                int first = state == -1 ? other : u;
                int second = state == -1 ? u : other;
                if (anc == first) {
                    ans[id] = ancestor(seqx[first], seqx[second]);
                } else if (anc == second) {
                    ans[id] = ancestor(seqx[second], seqx[first]);
                } else {
                    ans[id] =
                        "LCA of " + seqx[first] + " and " + seqx[second] + " is " + seqx[anc] + "." + "\n";
                }
            }
        }
        st[u] = 2;
    }

    /** Union-find lookup with path compression; returns the representative of {@code x}. */
    public static int find(int x) {
        int root = x;
        while (p[root] != root) {
            root = p[root];
        }
        // Point every node on the walked path directly at the representative.
        while (p[x] != root) {
            int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }

    /**
     * Adds the tree edges for the subtree whose inorder slice is seq[ml..mr] and whose
     * preorder slice starts at pl (mid = seq, pre = seqx). {@code pr} is only forwarded.
     */
    public static void build(int ml, int mr, int pl, int pr) {
        // Locate the subtree root (first preorder key) inside the inorder slice.
        int split = ml;
        for (int i = ml; i <= mr; i++) {
            if (seq[i] == seqx[pl]) {
                split = i;
                break;
            }
        }
        int leftSize = split - ml;
        if (leftSize > 0) {
            int leftRoot = pl + 1;
            add(pl, leftRoot);
            add(leftRoot, pl);
            build(ml, split - 1, leftRoot, leftRoot + leftSize);
        }
        if (split < mr) {
            // The right subtree's root follows the root and the whole left subtree in preorder.
            int rightRoot = pl + 1 + leftSize;
            add(pl, rightRoot);
            add(rightRoot, pl);
            build(split + 1, mr, rightRoot, pr);
        }
    }

    /** Appends the directed edge a -> b to the front of a's adjacency list. */
    public static void add(int a, int b) {
        int slot = idx++;
        e[slot] = b;
        ne[slot] = h[a];
        h[a] = slot;
    }
}
厉害