题目描述
难度分:2100
输入T(≤104)表示T组数据。所有数据的n之和≤2×105,m之和≤2×105。
每组数据输入n(2≤n≤2×105),m(1≤m≤2×105),表示一个n点m边的无向图。
然后输入s和t(s≠t),表示起点和终点。节点编号从1开始。
然后输入m条边,每条边输入x、y,表示有一条无向边连接x和y。
保证图是连通的。保证图中无自环和重边。
设从s到t的最短路长度为d。输出从s到t的路径个数,要求路径的长度至多为d+1。答案模109+7。
输入样例
4
4 4
1 4
1 2
3 4
2 3
2 4
6 8
6 1
1 4
1 6
1 5
1 2
5 6
4 6
6 3
2 6
5 6
1 3
3 5
5 4
3 1
4 2
2 1
1 4
8 18
5 1
2 1
3 1
4 2
5 2
6 5
7 3
8 4
6 4
8 7
1 4
4 7
1 6
6 7
3 8
8 5
4 5
4 3
8 2
输出样例
2
4
1
11
算法
BFS
+前后缀分解
比较容易想到要从起点s和终点t分别进行BFS
,然后枚举一个中间节点i做前后缀分解,组合前面的方案(从s到i的路径)和后面的方案(从i到t的路径)。但是有一些细节要处理一下,在做BFS
的过程中我们其实就可以通过DP
来计数。
动态规划
状态定义
f[i]表示从起点到i是最短路的方案数。
状态转移
当遍历到某个节点u时,展开它的所有邻居v:
- 如果dist[v]<dist[u]+1,更新最短路,扩展队列。
- 如果dist[v]=dist[u]+1,说明此时的最短路是不变的,直接把起点到u的方案数f[u]累加到f[v]上,也就是说有多少条到u的最短路就有多少条到v的最短路。
但我们做完BFS
之后还只是求得了最短路的方案数,如果以s为起点做BFS
得到的DP
数组为fs,把答案初始化为fs[t],这就是从s到t距离为最短路d时的所有方案数。
然后我们直接考虑距离为d+1时的情况,假设以s为起点得到的dist数组为ds,DP
数组为fs,以t为起点得到的dist数组为dt,DP
数组为ft。枚举一个中间节点i,我们希望这时候从s经过i到t时的距离为d+1,说明要多经过一条边到j,然后j继续走最短路,可以枚举i的邻居j。
此时需要满足以下两个条件:
- ds[i]+dt[j]=d,这样算上i到j这条边就能得到d+1这个距离。
- ds[i]=ds[j],这就说明从s到j的最短路根本就没必要经过i,i到j这条边根本就是强行加上去的。
把所有满足这两个条件的fs[i]×ft[j]累加起来就是距离为d+1时的路径数,算上前面距离为d时的路径数就是最终答案。
复杂度分析
时间复杂度
跑两遍BFS
求s和t到各点的最短路,在遍历的过程中利用DP
统计,时间复杂度为O(n+m)。然后遍历中间点i及其邻居j,在最差情况下其实就是遍历所有节点和边,时间复杂度仍然是O(n+m)。因此,整个算法的时间复杂度为O(n+m)。
空间复杂度
图的邻接表空间消耗为O(n+m),BFS
求最短路时的dist数组和DP
数组f的空间都是线性的,为O(n)。因此,整个算法的额外空间复杂度为O(n+m)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010, INF = 0x3f3f3f3f, MOD = 1e9 + 7;
int T, n, m;
vector<int> graph[N];
pair<vector<int>, vector<int>> bfs(int src) {
vector<int> dist(n + 1, INF);
queue<int> q;
q.push(src);
dist[src] = 0;
vector<int> f(n + 1);
f[src] = 1;
while(!q.empty()) {
int cur = q.front();
q.pop();
for(int nxt: graph[cur]) {
if(dist[nxt] == INF) {
dist[nxt] = dist[cur] + 1;
q.push(nxt);
}
if(dist[nxt] == dist[cur] + 1) {
f[nxt] = (f[nxt] + f[cur]) % MOD;
}
}
}
return make_pair(dist, f);
}
void solve() {
scanf("%d%d", &n, &m);
int s, t;
scanf("%d%d", &s, &t);
for(int i = 1; i <= n; i++) {
graph[i].clear();
}
for(int i = 1; i <= m; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
auto [ds, fs] = bfs(s);
auto [dt, ft] = bfs(t);
int d = ds[t], ans = fs[t];
for(int i = 1; i <= n; i++) {
for(int j: graph[i]) {
if(ds[i] + dt[j] == d && ds[i] == ds[j]) {
ans = (ans + (LL)fs[i] * ft[j] % MOD) % MOD;
}
}
}
printf("%d\n", ans);
}
int main() {
scanf("%d", &T);
while(T--) {
solve();
}
return 0;
}