题意
给定 n,m,p,求 Cmn+mmod,保证 p 为质数。
1 \le n,m,p \le 10^5
分析
看到题目,感觉是水题,直接写了求阶乘,然后求费马小定理求逆元。但是 \text{WA} 了。
我们发现,这题里没有保证 n < p 或 \gcd(n,p) = 1,所以不一定有逆元。如果不能求逆元,我们就要换一种方法。这里要用到 Lucas 定理
Lucas 定理如下。如果 P 为质数。
C_n^m \equiv C_{n \bmod P}^{m \bmod P} \times C_{\left \lfloor \frac{n}{P} \right \rfloor}^{\left \lfloor \frac{m}{P} \right \rfloor} \pmod P
也就是说,我们可以递归求解。因为 n \bmod P < P,所以前一项可以用逆元求解。后一项递归求解。因为每次至少会变成之前的一半,所以时间复杂度是 O(\log n) 的。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
#define N 100005
il int rd(){
int s = 0, w = 1;
char ch = getchar();
for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
return s * w;
}
int n, m, P, fact[N], inv_fact[N];
il int ksm(int x, int r, int P){
int ans = 1;
for (; r; x = 1ll * x * x % P, r >>= 1) if (r & 1) ans = 1ll * ans * x % P;
return ans;
}
int C(int n, int m, int P){return (n < m ? 0 : 1ll * fact[n] * inv_fact[m] % P * inv_fact[n - m] % P);}
int Lcs(int n, int m, int P){return (n < m ? 0 : !n ? 1 : 1ll * Lcs(n / P, m / P, P) * C(n % P, m % P, P) % P);}
int Main(){
n = rd(), m = rd(), P = rd(), fact[0] = 1;
for (int i = 1; i < P; i++) fact[i] = 1ll * fact[i - 1] * i % P;
inv_fact[P - 1] = ksm(fact[P - 1], P - 2, P);
for (int i = P - 2; i >= 0; i--) inv_fact[i] = 1ll * inv_fact[i + 1] * (i + 1) % P;
printf ("%d\n", Lcs(n + m, m, P));
return 0;
}
int main(){
for (int T = rd(); T--;) Main();
return 0;
}
P4720 【模板】扩展卢卡斯定理/exLucas
这题我们会发现 P 不一定为质数。考虑要求解,可以先质因数分解。
P = p_1^{r_1}p_2^{r_2}p_3^{r_3}\dots p_k^{r_k}
我们只需要求 C_n^m \bmod p^r,然后用 CRT 求解。
但是我们还是不能求逆元,因为 m! 不一定与 p^r 互质。考虑转化一下。
C_n^m \bmod p^r = \frac{\frac{n!}{p^x}}{\frac{m!}{p^y}\frac{(n-m))!}{p^z}} p^{x-y-z} \bmod p^r
这里 x 表示 n! 中 p 的质因子个数,y,z 同理。我们现在只要求 \frac{n!}{p^x} \bmod p^r 即可。
定义 f(n) = \frac{n!}{p^x} \bmod p^r。
我们把 n! 拆开。
n! = 1 \times 2 \times 3 \dots \times n = (p \times 2p \times 3p \dots \times \left \lfloor \frac{n}{p} \right \rfloor p)\times(1\times 2 \times \dots)
左边括号是所有 p 的倍数相乘,右边是剩余的数字相乘。左边提出来 p,得到的是 \left \lfloor n/p \right \rfloor!,主要考虑右边。
\prod_{i=1,\gcd(i,P)=1}^n i \bmod P
很明显,这个是有循环节的,直接提出来,不组成整个节的暴力算。因为我不想写 \text{Latex},所以直接给拆出来的式子了。
f(n) \equiv f(\left \lfloor \frac{n}{p} \right \rfloor)(\prod_{i=0,\gcd(i,P)=1}^{p^r} i) ^ {\left \lfloor \frac{n}{p^r} \right \rfloor}(\prod_{i=p^r\left \lfloor \frac{n}{p^r} \right \rfloor,\gcd(i,P)=1}^n i) \pmod {p^r}
于是我们就得到了。
C_n^m \bmod p^r = \frac{f(n)}{f(m)f(n-m)}p^{x-y-z}
我们最后一个问题就是求 x,y,z。定义 g(n) 表示 x。
g(n) = \left \lfloor \frac{n}{p} \right \rfloor + g(\left \lfloor \frac{n}{p} \right \rfloor)
综上求 f(n) 是 O(p\log_p n) 的,求 g(n) 是 O(\log_p n)。
最后用 CRT 即可。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
il ll rd(){
ll s = 0, w = 1;
char ch = getchar();
for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
return s * w;
}
ll n, m;
int P;
int exgcd(int a, int b, int &x, int &y){
if (b == 0) return x = 1, y = 0, a;
int k = exgcd(b, a % b, y, x);
return y -= 1ll * a / b * x, k;
}
il int ksm(int x, ll r, int p){
int ans = 1;
for (; r; x = 1ll * x * x % P, r >>= 1) if (r & 1) ans = 1ll * ans * x % P;
return ans;
}
il int inv(int q, int p){
int x, y;
return exgcd(q, p, x, y), x += p, (x >= p ? x - p : x);
}
int fact(ll n, int pk, int p){
if (n == 0) return 1;
int ans = 1;
for (int i = 2; i <= pk; i++) if (i % p) ans = 1ll * ans * i % pk;
ans = ksm(ans, n / pk, pk);
for (int i = 2; i <= n % pk; i++) if (i % p) ans = 1ll * ans * i % pk;
return 1ll * ans * fact(n / p, pk, p) % pk;
}
il int CRT(int x, int p){return 1ll * x * inv(P / p, p) % P * (P / p) % P; }
il int C(ll n, ll m, int pk, int p){
ll s = fact(n, pk, p), x = fact(m, pk, p), y = fact(n - m, pk, p), k = 0;
for (ll i = n; i; i /= p) k += i / p;
for (ll i = m; i; i /= p) k -= i / p;
for (ll i = n - m; i; i /= p) k -= i / p;
return 1ll * s * inv(x, pk) % pk * inv(y, pk) % pk * ksm(p, k, pk) % pk;
}
il int exLcs(ll x, ll m, ll P){
int tmp = P, ans = 0;
for (ll i = 2; i * i <= tmp; i++){
int pk = 1;
while (tmp % i == 0) pk *= i, tmp /= i;
ans = (ans + CRT(C(n, m, pk, i), pk)) % P;
}
if (tmp > 1) ans = (ans + CRT(C(n, m, tmp, tmp), tmp)) % P;
return ans;
}
signed main(){
n = rd(), m = rd(), P = rd();
printf ("%lld\n", exLcs(n, m, P));
return 0;
}