【组合数学】专题笔记目录
现在要从 (a,b,c) 走向 (0,0,0),假设中间有一个状态是 (x,y,z)。
那么走到 (x,y,z) 需要 a+b+c−x−y−z 步,由于每一步可以在三个维度中的任意一个选择,所以方案数是 3a+b+c−x−y−z。
但是这显然是不满足走到 (x,y,z) 的要求,因为每一个维度有限定步数,必须恰好走 a−x,b−y,c−z 步。
那么考虑按位满足条件。第一维满足条件的方案数即 Ca−xa+b+c−x−y−z,第二维满足条件的方案数就是 Cb−yb+c−y−z(因为第一维已经被选中),第三维就是剩下的步数。
因此根据乘法原理,可以求出走到 (x,y,z) 概率为:
Ca−xa+b+c−x−y−zCb−yb+c−y−z3a+b+c−x−y−z
我们接下来假设它只撞第三维,因为撞到前两维的边界是和这种情况同理的。
考虑最朴素的做法:枚举 x,y,并根据概率求出期望。观察到这和 z 这一项已经没有关系了,因为 z=0。即:
Ans=a∑x=0b∑y=0Cca+b+c−x−yCa−xa+b−x−y×(x+y)k3a+b+c−x−y
上述式子的意义是,先选 c 步使得第三维撞墙,剩下的一二维再选,最后得到概率再乘上贡献得到期望。
然后我们切换角度,枚举 p=x+y,可以把式子变成这样(中间有两三步移项等操作):
a+b∑p=0Cca+b+c−ppk3a+b+c−pp∑x=0Ca−xa+b−p
设 f(p)=p∑x=0Ca−xa+b−p,则式子转化为:
a+b∑p=0Cca+b+c−ppk3a+b+c−pf(p)
观察到如果我们已知 f,可以 O(A+B+C) 地解决掉这个问题,那么如何快速处理 f 呢?
根据 Cmn=Cm−1n−1+Cmn−1 把组合数拆出来:f(p)=p∑x=0Ca−xa+b−p=p∑x=0Ca−x−1a+b−p−1+p∑x=0Ca−xa+b−p−1
由于 f(p+1)=p+1∑x=0Ca−xa+b−p−1,代入得到递推式!
f(p)=2f(p+1)−Caa+b−p−1−Cba+b−p−1
于是可以线性复杂度地处理出 f 了,问题便迎刃而解。
至此,只要预处理 f 便可以通过这题的 50 部分分。
是哪里被卡了呢?复杂度是 O(A+B+C) 显然能过啊!
预处理 3 的次幂的逆元,和 pk,即可降低大常数。
#include <bits/stdc++.h>
using namespace std;
const int N = 1.5e7 + 15, mod = 998244353;
int a, b, c, k;
int qmi(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = (res * 1ll * a) % mod;
a = (a * 1ll * a) % mod;
k >>= 1;
}
return res;
}
int fac[N], inv[N];
int inv3[N], pw[N];
void init(int lim) {
inv3[a + b + c] = qmi(qmi(3, a + b + c), mod - 2);
for (register int i = a + b + c - 1; i >= 1; i--) inv3[i] = (inv3[i + 1] * 3ll) % mod;
for (int i = 1; i <= a + b + c; i++) pw[i] = qmi(i, k);
fac[0] = 1;
for (register int i = 1; i <= lim; i++) fac[i] = (fac[i - 1] * 1ll * i) % mod;
inv[lim] = qmi(fac[lim], mod - 2);
for (register int i = lim - 1; i >= 0; i--) inv[i] = (inv[i + 1] * 1ll * (i + 1)) % mod;
}
inline int C(int n, int m) {
if (m < 0 || m > n) return 0;
return (fac[n] * 1ll * inv[m] % mod * 1ll * inv[n - m]) % mod;
}
int f[N];
long long ans = 0;
void solve(int a, int b, int c) {
f[0] = 1;
for (register int p = 1; p <= a + b; p++) //改为枚举a+b-p
f[p] = (f[p - 1] * 2ll % mod - C(p - 1, a) + mod % mod - C(p - 1, b) + mod) % mod;
for (register int p = 0; p <= a + b; p++)
(ans += ((long long)C(a + b + c - p, c) * pw[p] % mod * inv3[a + b + c - p] % mod * f[a + b - p] % mod) + mod) %= mod;
}
int main() {
scanf("%d%d%d%d", &a, &b, &c, &k);
init(a + b + c);
solve(a, b, c), solve(b, c, a), solve(c, a, b);
printf("%lld\n", ans * 1ll * inv3[1] % mod);
return 0;
}