$$【组合数学】专题笔记目录$$
现在要从 $(a,b,c)$ 走向 $(0,0,0)$,假设中间有一个状态是 $(x,y,z)$。
那么走到 $(x,y,z)$ 需要 $a+b+c-x-y-z$ 步,由于每一步可以在三个维度中的任意一个选择,所以方案数是 $3^{a+b+c-x-y-z}$。
但是这显然是不满足走到 $(x,y,z)$ 的要求,因为每一个维度有限定步数,必须恰好走 $a-x,b-y,c-z$ 步。
那么考虑按位满足条件。第一维满足条件的方案数即 $C_{a+b+c-x-y-z}^{a-x}$,第二维满足条件的方案数就是 $C_{b+c-y-z}^{b-y}$(因为第一维已经被选中),第三维就是剩下的步数。
因此根据乘法原理,可以求出走到 $(x,y,z)$ 概率为:
$$\frac{ C_{a+b+c-x-y-z}^{a-x} C_{b+c-y-z}^{b-y} }{ 3^{a+b+c-x-y-z} }$$
我们接下来假设它只撞第三维,因为撞到前两维的边界是和这种情况同理的。
考虑最朴素的做法:枚举 $x,y$,并根据概率求出期望。观察到这和 $z$ 这一项已经没有关系了,因为 $z=0$。即:
$$Ans = \sum\limits_{x=0}^{a} \sum\limits_{y=0}^{b} \frac{ C_{a+b+c-x-y}^{c} C_{a+b-x-y}^{a-x} \times (x + y)^k }{ 3^{a+b+c-x-y} }$$
上述式子的意义是,先选 $c$ 步使得第三维撞墙,剩下的一二维再选,最后得到概率再乘上贡献得到期望。
然后我们切换角度,枚举 $p=x+y$,可以把式子变成这样(中间有两三步移项等操作):
$$\sum\limits_{p=0}^{a+b} \frac{C_{a+b+c-p}^{c} p^k}{3^{a+b+c-p}} \sum\limits_{x=0}^{p} C_{a+b-p}^{a-x}$$
设 $f(p)=\sum\limits_{x=0}^{p} C_{a+b-p}^{a-x}$,则式子转化为:
$$\sum\limits_{p=0}^{a+b} \frac{C_{a+b+c-p}^{c} p^k}{3^{a+b+c-p}} f(p)$$
观察到如果我们已知 $f$,可以 $O(A+B+C)$ 地解决掉这个问题,那么如何快速处理 $f$ 呢?
根据 $C_{n}^{m}=C_{n-1}^{m-1} + C_{n-1}^{m}$ 把组合数拆出来:$f(p)=\sum\limits_{x=0}^{p} C_{a+b-p}^{a-x} = \sum\limits_{x=0}^{p} C_{a+b-p-1}^{a-x-1} + \sum\limits_{x=0}^{p} C_{a+b-p-1}^{a-x}$
由于 $f(p+1)=\sum\limits_{x=0}^{p+1} C_{a+b-p-1}^{a-x}$,代入得到递推式!
$$f(p) = 2f(p+1) - C_{a+b-p-1}^{a} - C_{a+b-p-1}^{b}$$
于是可以线性复杂度地处理出 $f$ 了,问题便迎刃而解。
至此,只要预处理 $f$ 便可以通过这题的 50 部分分。
是哪里被卡了呢?复杂度是 $O(A+B+C)$ 显然能过啊!
预处理 $3$ 的次幂的逆元,和 $p^k$,即可降低大常数。
#include <bits/stdc++.h>
using namespace std;
const int N = 1.5e7 + 15, mod = 998244353;
int a, b, c, k;
int qmi(int a, int k) {
int res = 1;
while (k) {
if (k & 1) res = (res * 1ll * a) % mod;
a = (a * 1ll * a) % mod;
k >>= 1;
}
return res;
}
int fac[N], inv[N];
int inv3[N], pw[N];
void init(int lim) {
inv3[a + b + c] = qmi(qmi(3, a + b + c), mod - 2);
for (register int i = a + b + c - 1; i >= 1; i--) inv3[i] = (inv3[i + 1] * 3ll) % mod;
for (int i = 1; i <= a + b + c; i++) pw[i] = qmi(i, k);
fac[0] = 1;
for (register int i = 1; i <= lim; i++) fac[i] = (fac[i - 1] * 1ll * i) % mod;
inv[lim] = qmi(fac[lim], mod - 2);
for (register int i = lim - 1; i >= 0; i--) inv[i] = (inv[i + 1] * 1ll * (i + 1)) % mod;
}
inline int C(int n, int m) {
if (m < 0 || m > n) return 0;
return (fac[n] * 1ll * inv[m] % mod * 1ll * inv[n - m]) % mod;
}
int f[N];
long long ans = 0;
void solve(int a, int b, int c) {
f[0] = 1;
for (register int p = 1; p <= a + b; p++) //改为枚举a+b-p
f[p] = (f[p - 1] * 2ll % mod - C(p - 1, a) + mod % mod - C(p - 1, b) + mod) % mod;
for (register int p = 0; p <= a + b; p++)
(ans += ((long long)C(a + b + c - p, c) * pw[p] % mod * inv3[a + b + c - p] % mod * f[a + b - p] % mod) + mod) %= mod;
}
int main() {
scanf("%d%d%d%d", &a, &b, &c, &k);
init(a + b + c);
solve(a, b, c), solve(b, c, a), solve(c, a, b);
printf("%lld\n", ans * 1ll * inv3[1] % mod);
return 0;
}