P3398 仓鼠找 sugar
题目描述
小仓鼠的和他的基(mei)友(zi)sugar住在地下洞穴中,每个节点的编号为1~n。地下洞穴是一个树形结构。这一天小仓鼠打算从从他的卧室(a)到餐厅(b),而他的基友同时要从他的卧室(c)到图书馆(d)。他们都会走最短路径。现在小仓鼠希望知道,有没有可能在某个地方,可以碰到他的基友?
输入格式
第一行两个正整数n和q,表示这棵树节点的个数和询问的个数。
接下来n-1行,每行两个正整数u和v,表示节点u到节点v之间有一条边。
接下来q行,每行四个正整数a、b、c和d,表示节点编号,也就是一次询问,其意义如上。
100%的数据 n<=100000,q<=100000
分析:
由于q的级别为10^5,n的级别为10^5,如果lca+前缀和,复杂度会达到10^10
所以肯定不能爆搜,不如看看有什么性质
题意抽象出来就是 在一棵树上,给两组点,判读这两组点的最短路间的点是否有交集
下面只分析半边的情况,因为另一半是完全对称的
对于两组点的最短路有交集,那么这个点一定在a,b路径上,下面进行分类讨论,设ab祖先为lcaA,cd祖先为lcaB
首先证明一个结论,对于一棵树,若c在ab路径上,一定有dis[a,c]+dis[c,b] = dis[a,b]
反证法:假设c不在ab路径上但有dis[a,c]+dis[c,b] = dis[a,b]
那么bc肯定存在一条路(如图)
由于树上两点的最短路是唯一的,
假设c在a的下面 那么dis[c,b]一定大于dis[a,b]
假设c在b的上面 那么dis[a,c]一定大于dis[a,b]
综上,若c不在ab上面,结论矛盾
所以若c在ab路径上,一定有dis[a,c]+dis[c,b] = dis[a,b]
有了这条定理,要快速判断某个点是否在某两个点的最短路上就很方便
只需要判断dis[a,c]+dis[c,lca(a,b)] == dis[a,lca(a,b)]即可
有两组点,需要判断六次(包括祖先)
代码
import java.io.*;
import java.util.*;
public class Main{
static int N = 100010,M = 2*N;
static int[] h = new int[N],e = new int[M],ne = new int[M];
static int[] depth = new int[N];
static int[][] fa = new int[N][20];
static int n,m,idx = 0,INF = 0x3f3f3f3f;
static BufferedReader br
= new BufferedReader(new InputStreamReader(System.in));
static StreamTokenizer sc = new StreamTokenizer(br);
static BufferedWriter bw
= new BufferedWriter(new OutputStreamWriter(System.out));
public static void main(String[] args)throws IOException{
n = nextRead();
m = nextRead();
Arrays.fill(h,-1);
for(int i = 1;i<=n-1;i++){
int a = nextRead();
int b = nextRead();
add(a,b);
add(b,a);
}
bfs();
for(int i = 1;i<=m;i++){
int a = nextRead(),b = nextRead();
int c = nextRead(),d = nextRead();
int lcaA = lca(a,b),lcaB = lca(c,d);
boolean success = false;
success |= is_inEdge(a,b,c,lcaA);
success |= is_inEdge(a,b,d,lcaA);
success |= is_inEdge(a,b,lcaB,lcaA);
success |= is_inEdge(c,d,a,lcaB);
success |= is_inEdge(c,d,b,lcaB);
success |= is_inEdge(c,d,lcaA,lcaB);
if(success) bw.write("Y\n");
else bw.write("N\n");
}
bw.flush();
bw.close();
br.close();
}
public static boolean is_inEdge(int a,int b,int c,int lcaA){
if(dis(a,c)+dis(c,lcaA) == dis(a,lcaA)) return true;
if(dis(b,c)+dis(c,lcaA) == dis(b,lcaA)) return true;
return false;
}
public static int dis(int a,int b){
return depth[a]+depth[b]-2*depth[lca(a,b)];
}
public static int lca(int a,int b){
if(depth[a]<depth[b]){
int c = a;
a = b;
b = c;
}
for(int k = 19;k>=0;k--)
if(depth[fa[a][k]]>=depth[b])
a = fa[a][k];
if(a == b) return a;
for(int k = 19;k>=0;k--)
if(fa[a][k]!=fa[b][k]){
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
public static void bfs(){
Arrays.fill(depth,INF);
int[] q = new int[N];
int hh = 0,tt = -1;
depth[0] = 0;
depth[1] = 1;
q[++tt] = 1;
while(tt>=hh){
int t = q[hh++];
for(int i = h[t];i!=-1;i = ne[i]){
int j = e[i];
if(depth[j]>depth[t]+1){
depth[j] = depth[t]+1;
q[++tt] = j;
fa[j][0] = t;
for(int k = 1;k<=19;k++)
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
}
public static void add(int a,int b){
e[idx] = b;
ne[idx] = h[a];
h[a] = idx++;
}
public static int nextRead()throws IOException{
sc.nextToken();
return (int)sc.nval;
}
}
3136.求和
分析:预处理出所有1~50的深度前缀和,y总模板在做的时候计算的是点深度,而边深度=点深度-1
公式:sum[k][depth[a]-1]+sum[k][depth[b]-1]-sum[k][depth[anc]]-sum[k][depth[fa[anc][0]]];
注意取模要用 (x%mod+mod)%mod;
代码:
import java.io.*;
import java.util.*;
public class Main{
static int N = 300010,M = 2*N;
static int[] h = new int[N],e = new int[M],ne = new int[M];
static int[] depth = new int[N];
static long[][] sum = new long[51][N];
static int[][] fa = new int[N][20];
static int idx = 0,n,m,INF = 0x3f3f3f3f,mod = 998244353;
static BufferedReader br
= new BufferedReader(new InputStreamReader(System.in));
static StreamTokenizer sc = new StreamTokenizer(br);
static BufferedWriter bw
= new BufferedWriter(new OutputStreamWriter(System.out));
public static void main(String[] args)throws IOException{
n = nextRead();
Arrays.fill(h,-1);
for(int i = 0;i<=50;i++){
Arrays.fill(sum[i],1);
sum[i][0] = 0;
}
for(int k = 1;k<=50;k++){
for(int i = 1;i<N;i++){
sum[0][i]= mmod(sum[0][i]*i,mod);
sum[k][i] = mmod(sum[0][i]+sum[k][i-1],mod);
}
}
for(int i = 1;i<=n-1;i++){
int a = nextRead();
int b = nextRead();
add(a,b);
add(b,a);
}
bfs();
m = nextRead();
for(int i = 1;i<=m;i++){
int a = nextRead();
int b = nextRead();
int k = nextRead();
int anc = lca(a,b);
long ans = mmod(sum[k][depth[a]-1]+sum[k][depth[b]-1],mod);
ans = mmod(ans-sum[k][depth[anc]-1],mod);
ans = mmod(ans-sum[k][Math.max(depth[fa[anc][0]]-1,0)],mod);
bw.write(ans+"\n");
}
bw.flush();
bw.close();
br.close();
}
public static long mmod(long x,int mod){
return (x%mod+mod)%mod;
}
public static int lca(int a,int b){
if(depth[a]<depth[b]){
int c = a;
a = b;
b = c;
}
for(int k = 19;k>=0;k--)
if(depth[fa[a][k]]>=depth[b])
a = fa[a][k];
if(a == b) return a;
for(int k = 19;k>=0;k--)
if(fa[a][k]!=fa[b][k]){
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
public static void bfs(){
Arrays.fill(depth,INF);
depth[0] = 0;
depth[1] = 1;
Queue<Integer> q = new LinkedList<>();
q.add(1);
while(!q.isEmpty()){
int t = q.poll();
for(int i = h[t];i!=-1;i = ne[i]){
int j = e[i];
if(depth[j]>depth[t]+1){
depth[j] = depth[t]+1;
q.add(j);
fa[j][0] = t;
for(int k = 1;k<=19;k++)
fa[j][k] = fa[fa[j][k-1]][k-1];
}
}
}
}
public static void add(int a,int b){
e[idx] = b;
ne[idx] = h[a];
h[a] = idx++;
}
public static int nextRead()throws IOException{
sc.nextToken();
return (int)sc.nval;
}
}