算法
(NTT, 生成函数) $O(m \log m)$
设第 $i$ 个数为 $v _ i$。(改个名字看着舒服仅此而已=,=)
设从中选出若干数,总和为 $m$ 的方案数为 $a _ m$。
设 $\{a\}$ 的生成函数 $\begin {aligned} f(x) = \sum _ {i = 0} ^ {+ \infty} a _ i x ^ i \end {aligned}$。
那么有 $\begin {aligned} f(x) = \prod _ {i = 1} ^ n (1 + x ^ {v _ i}) \end {aligned}$。
乘积不好算,用 $\exp$ 和 $\ln$ 将乘积转为求和。
$$ \begin {aligned} f(x) = & \prod _ {i = 1} ^ n (1 + x ^ {v _ i}) \\\ = & \exp(\ln \prod _ {i = 1} ^ n (1 + x ^ {v _ i})) \\\ = & \exp(\sum _ {i = 1} ^ n \ln(1 + x ^ {v _ i})) \\\ \end {aligned} $$
引理:$\begin {aligned} \ln(1 + t) = \sum _ {k = 1} ^ {+ \infty} (-1) ^ {k + 1} \frac {t ^ k} k \end {aligned}$(其实就是泰勒展开啦)
非泰勒展开证明:
$$ \begin {aligned} \ln(1 + t) & = \int \frac 1 {1 + t} \text d t \\\ & = \int \sum _ {k = 0} ^ {+ \infty} (-t) ^ k \text d t \\\ & = \sum _ {k = 0} ^ {+ \infty} (-1) ^ k \int t ^ k \text d t \\\ & = \sum _ {k = 0} ^ {+ \infty} (-1) ^ k \frac {t ^ {k + 1}} {k + 1} \\\ & = \sum _ {k = 1} ^ {+ \infty} (-1) ^ {k + 1} \frac {t ^ k} k \\\ \end {aligned} $$
继续推式子:
$$ \begin {aligned} f(x) = & \exp(\sum _ {i = 1} ^ n \ln(1 + x ^ {v _ i})) \\\ = & \exp(\sum _ {i = 1} ^ n \sum _ {k = 1} ^ {+ \infty} (-1) ^ {k + 1} \frac {x ^ {v _ i k}} k) \\\ \end {aligned} $$
到这一步其实就已经可以直接做了。我们可以将 $\exp$ 中的多项式求出来,然后多项式 $\exp$ 即可。
求 $\exp$ 中多项式代码如下:
(这几段代码为了方便理解而补充的,不要乱抄,乱抄会导致爆 int
等各种错误发生)
for (int i = 1; i <= n; i ++ )
for (int k = 1; v[i] * k <= m; k ++ )
if (k & 1)
a[v[i] * k] += Inv(k);
else
a[v[i] * k] -= Inv(k);
注意到这部分时间复杂度为 $\begin {aligned} O(\sum _ {i = 1} ^ n \frac m {v _ i}) \end {aligned}$,如果每个 $v _ i$ 的值都是 $1$(虽然数据中好像并没有这种情况),复杂度会退化到 $O(m n)$。
注意到 $v _ i$ 的值不会很大,我们可以将每种 $v _ i$ 的取值统计出来。
设 $t _ x$ 表示有多少个 $v _ i = x$,上面代码可以优化至如下代码:
for (int i = 1; i <= n; i ++ ) t[v[i]] ++ ;
for (int i = 1; i <= m; i ++ )
if (t[i])
for (int k = 1; i * k <= m; k ++ )
if (k & 1)
a[i * k] += t[i] * Inv(k);
else
a[i * k] -= t[i] * Inv(k);
这段代码时间复杂度为 $ \begin {aligned} O(\sum _ {i = 1} ^ m \frac m i) = O(m \ln m) \end {aligned} $
多项式 $\exp$ 做法复杂度为 $O(m \log m)$,所以总复杂度为 $O(m \log m)$。
由于数据范围,优化的效果并不明显(其实是退化了)。当然这里面也有常数原因。
就当是拓展一下思路 & 复习 NTT 吧。
细节处理参考代码,懒得注释了。
凭感觉可知答案不会太大,这里取 $\bmod = 998244353$,如果哪天挂了请告诉我。
C++ 代码
#include <cstdio>
#include <cstring>
typedef long long ll;
const int N = 40005;
const int mod = 998244353, G = 3;
int n, m;
int last = -1, lim, rev[N];
int a[N], b[N], rt[N];
inline void bmod(int& x) {x += x >> 31 & mod;}
inline void swap(int& a, int& b) {a ^= b ^= a ^= b;}
inline void reverse(int* st, int* ed) {while (st < --ed) swap(*st++, *ed);}
int inv(int x, int k = mod - 2)
{
int r = 1;
while (k)
{
if (k & 1) r = (ll)x * r % mod;
x = (ll)x * x % mod;
k >>= 1;
}
return r;
}
struct NT
{
int inverse[N];
void prework(int n)
{
inverse[1] = 1;
for (int i = 2; i <= n; i ++ )
bmod(inverse[i] = -(ll)inverse[mod % i] * (mod / i) % mod);
}
int operator() (const int x) {return inverse[x];}
} Inv;
/*****NTT板子*****/
void prework(int n)
{
if (last == n) return;
last = n, lim = 1;
while (lim <= n) lim <<= 1;
for (int i = 0; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1 ? lim >> 1 : 0);
for (int k = 2, i; k <= lim; k <<= 1)
{
rt[i = k >> 1] = 1;
ll v = inv(G, (mod - 1) / k);
for (int j = i + 1; j < k; ++j) rt[j] = rt[j - 1] * v % mod;
}
}
void NTT(int* f, bool flag = true)
{
if (!flag) reverse(f + 1, f + lim);
for (int i = 0; i < lim; i ++ )
if (i < rev[i])
swap(f[i], f[rev[i]]);
int x, *bf;
for (int k = 2, len, i, j; k <= lim; k <<= 1)
for (len = k >> 1, i = 0; i < lim; i += k)
for (bf = rt + len, j = i; j < i + len; j ++, bf ++ )
{
x = (ll)f[len + j] * *bf % mod;
bmod(f[len + j] = f[j] - x);
bmod(f[j] += x - mod);
}
if (flag) return;
ll v = inv(lim);
for (int i = 0; i < lim; i ++ ) f[i] = f[i] * v % mod;
}
void fmul(int n, int m, int* f, int* g, int* res)
{
static int tf[N], tg[N]; prework(n + m);
memcpy(tf, f, n << 2), memset(tf + n, 0, lim - n << 2), NTT(tf);
memcpy(tg, g, m << 2), memset(tg + m, 0, lim - m << 2), NTT(tg);
for (int i = 0; i < lim; i ++ ) res[i] = (ll)tf[i] * tg[i] % mod;
NTT(res, false);
}
void finv(int n, int* f, int* g)
{
static int tf[N];
if (n == 1) return (void)(g[0] = inv(f[0]));
finv(n + 1 >> 1, f, g), prework(n << 1);
memcpy(tf, f, n << 2);
memset(tf + n, 0, lim - n << 2);
memset(g + n, 0, lim - n << 2);
NTT(tf), NTT(g);
for (int i = 0; i < lim; i ++ ) bmod(g[i] = (2 - (ll)tf[i] * g[i] % mod) * g[i] % mod);
NTT(g, false);
memset(g + n, 0, lim - n << 2);
}
void fln(int n, int* f, int* g)
{
static int tf[N];
finv(n, f, g);
for (int i = 0; i < n; i ++ ) tf[i] = (i + 1ll) * f[i + 1] % mod;
fmul(n, n, tf, g, tf);
for (int i = 1; i < n; i ++ ) g[i] = (ll)tf[i - 1] * Inv(i) % mod;
g[0] = 0;
}
void fexp(int n, int* f, int* g)
{
static int tf[N];
if (n == 1) return (void)(g[0] = 1);
fexp(n + 1 >> 1, f, g);
memset(tf, 0, n << 2);
fln(n, g, tf), --tf[0];
for (int i = 0; i < n; i ++ ) tf[i] = (f[i] - tf[i] + mod) % mod;
fmul(n, n, g, tf, g);
memset(g + n, 0, lim - n << 2);
}
/**********/
int t[N];
int main()
{
scanf("%d%d", &n, &m);
Inv.prework(m);
for (int i = 0; i < n; i ++ )
{
int v;
scanf("%d", &v);
t[v] ++ ;
}
for (int i = m; i; --i)
if (t[i])
for (int k = 1; i * k <= m; ++k)
{
int& x = a[i * k], y = (ll)t[i] * Inv(k) % mod;
bmod(x += k & 1 ? y - mod : -y);
}
fexp(m + 1, a, b);
printf("%d\n", b[m]);
return 0;
}
$$ $$
# 我看不懂, 但我大为震撼
假的抽风也来了# 我看不懂, 但我不震撼,躺平了
# $\huge{我看不懂, 但我大为震撼}$
# $\huge{我看不懂, 但我大为震撼}$
# $\huge{我看不懂, 但我大为震撼}$
# $\huge{我看不懂, 但我大为震撼}$
# $\huge{我看不懂, 但我大为震撼}$
Orz
# 我看不懂,但我大为震撼
###我看不懂, 但我大为震撼
不是哥们😡
我看不懂,但我大为震撼
我看不懂,但我大为震撼
# 我看不懂,但我大为震撼
# 我看不懂,但我大为震撼 #
这种求 $ \prod {1+x^A_n} $的问题都能这么做吗(
我看不懂,但我大为震撼
我真的看不懂,但我大为震撼
因为评论我给你顶一个
我看不懂, 但我大为震撼
## 虽然我看不懂,但我大为震撼
我看不懂,但我大为震撼
我看不懂,但我大为震撼
# 我看不懂, 但我大为震撼
我看不懂,但我大为震撼
你干嘛,哎哟