有一个 n+2 行 m 列的网格。
每天,除了第 0 行和第 n+1 行,其余的行最左边和最右边的格子会有 p 的概率被摧毁。
求第 k 天网格依然联通的概率。
比较幽默的题。首先如果墙没倒,一定是每一层还剩一个区间,并且相邻有交。直接做加上一点前缀和就是 O(nm2)。
然后我们发现,只需要算这个前缀和即可,根本不需要把每一项算出来。就变成了 O(nm)。
杜老师是这样的。
首先容易想到朴素 DP:设 fi,l,r 表示第 i 行,剩下 l,r,且 [0,i] 行都联通的概率。
由于只会摧毁左右端点,所以最后剩下的是一段区间。考虑求剩下 [l,r] 的概率。
由于左右是独立的,设 Pi 表示恰好摧毁 i 个的概率,那么 Pi=Ck,ipi×(1−p)k−i。
则剩下 [l,r] 的概率就是 Pl−1×Pm−r,由于最多只能摧毁 k 个,要特判端点范围。
则有:
fi,l,r=∑[l′,r′]∩[l,r]≠∅fi−1,l′,r′×Pl−1×Pm−r
这样需要枚举 i,l,r,l′,r′,显然是不优的。
考虑简化 O(nm2) 的状态:在转移的时候再枚举 l,状态中只记录 i,r。
设 gi,r 表示 [0,i] 这些行联通,且第 i 行的右端点为 r 的概率。
反之,设 hi,l 表示第 i 行左端点为 l 的概率,但由于网格是左右对称的,所以 gi,r=hi,m−r+1。
这样并不好直接算交集的贡献,那么正难则反,考虑容斥,用总数减去不交的贡献。
- r′<l:此时 [l′,r′] 位于 [l,r] 左边,贡献为 l−1∑j=1gi−1,j。
- l′>r:此时 [l′,r′] 位于 [l,r] 右边,贡献为 m∑j=r+1hi−1,j=m−r∑j=1gi−1,j。
因此有转移方程式:
gi,r=r∑l=1Pl−1×Pm−r×(m∑j=1gi−1,j−l−1∑j=1gi−1,j−m−r∑j=1gi−1,j)
后面三个都是前缀和形式,考虑设 si,r=r∑j=1gi,r。
想起了一位故人。
gi,r=r∑l=1Pl−1×Pm−r×(si−1,m−si−1,l−1−si−1,m−r)
至此复杂度为 O(nm2),可以成功 TLE 本题。
考虑把式子改的好看一些:
gi,r=Pm−r[(si−1,m−si−1,m−r)r∑l=1Pl−1−r∑l=1Pl−1×si−1,l−1]
这个式子还是很小清新的。
记录一下 Pl 的前缀和以及 pl×si−1,l 的前缀和就行了,可以做到 O(nm)。
#include <bits/stdc++.h>
using namespace std;
const int N = 1505, M = 1e5 + 5, mod = 1e9 + 7;
int n, m, A, B, k, p, p_1;
int qmi(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = res * 1ll * a % mod;
a = a * 1ll * a % mod, k >>= 1;
}
return res;
}
int fac[M], inv[M];
void initC() {
fac[0] = 1; for (int i = 1; i <= 100000; i++) fac[i] = (fac[i - 1] * 1ll * i) % mod;
inv[100000] = qmi(fac[100000], mod - 2);
for (int i = 99999; i >= 0; i--) inv[i] = inv[i + 1] * 1ll * (i + 1) % mod;
}
int C(int n, int m) { return fac[n] * 1ll * inv[m] % mod * 1ll * inv[n - m] % mod; }
long long P[N], sp[N];
void initP() {
p = ( A * 1ll * qmi(B, mod - 2) ) % mod, p_1 = (mod + 1 - p) % mod;
for (int i = 0; i <= min(m, k); i++)
P[i] = C(k, i) * 1ll * qmi(p, i) % mod * 1ll * qmi(p_1, k - i) % mod;
sp[0] = P[0]; for (int i = 1; i <= m; i++) sp[i] = (sp[i - 1] + P[i]) % mod;
}
long long g[N][N], s[N][N], ssp[N][N];
int main() {
initC();
scanf("%d%d%d%d%d", &n, &m, &A, &B, &k);
initP();
/*
g[0][m] = 1, s[0][m] = 1;
for (int i = 1; i <= n; i++) {
for (int r = 1; r <= m; r++) {
long long sum = 0;
for (int l = 1; l <= r; l++) {
long long res = P[l - 1] * P[m - r] % mod;
res = res * ( (s[i - 1][m] - s[i - 1][l - 1] - s[i - 1][m - r]) % mod + mod) % mod;
(sum += res % mod) %= mod;
}
g[i][r] = sum, s[i][r] = (s[i][r - 1] + sum) % mod;
}
}
*/
g[0][m] = 1, s[0][m] = 1, ssp[0][m] = P[m];
for (int i = 1; i <= n; i++) {
for (int r = 1; r <= m; r++) {
g[i][r] = ( (s[i - 1][m] - s[i - 1][m - r] + mod) * sp[r - 1] ) % mod;
g[i][r] = (g[i][r] - ssp[i - 1][r - 1] + mod) % mod;
(g[i][r] *= P[m - r]) %= mod;
}
for (int r = 1; r <= m; r++)
s[i][r] = (s[i][r - 1] + g[i][r]) % mod, ssp[i][r] = (ssp[i][r - 1] + s[i][r] * 1ll * P[r] % mod) % mod;
}
printf("%d\n", s[n][m]);
return 0;
}