小K最近刚刚习得了一种非常酷炫的多项式求和技巧,可以对某几类特殊的多项式进行运算。
既然题面都说了要用“炫酷的多项式求和技巧”了,那我们为何不直接使用多项式求和技巧呢?
首先我们先化简题目的表达式:
S(n)=n∑k=0akf(k)=n∑k=0ak(m∑i=0bixi)=m∑i=0(n∑k=0akki)
而后面这个东西 ∑nk=0akkm,只要接触过的话就再熟悉不过了,指数函数乘以幂函数的级数求和也是很久之前就在研究的问题了,具体的证明需要用到生成函数等相关知识,具体篇幅太长就不在这里展开了,可以到 这里 来看一看证明。在知道多项式表达式之后,我们就可以用拉格朗日插值来完成具体值的运算。现有的最快做法是可以结合质数筛法的做法,复杂度为 O(m+logn) ,而这个做法又有些冗长,所以在这里用一个轻便好写的 O(mlogm+logn) 就可以了。
本题的数据范围 m 非常的小,实际上官方给出的标准答案也是结合dp+矩阵快速幂,复杂度在 O(m3logn) 就行了,而我们这里只需要 O(m2logm+mlogn), 一点点地暴力计算上述 m+1 项结果即可。
该做法加强到 n≤10100000 的时候依旧可行(如果只算单项幂数乘指数求和的话, m 也可以加强到 106 这个级别,在1s内完成运算),因为我们只需要知道 n 对 mod 和对 mod−1 取模的值即可。(mod−1 是因为 n 需要做为指数进行运算,此时就需要结合费马小定理了)。如果 mod 不是质数,那么我们就可以用中国剩余定理,对多个模质数意义下的结果再做总结。
但是如果 m 太大(比如大到 109 左右,导致时间和空间复杂度都会特别大),而 n 非常小的话,那就可以直接用最暴力的做法完成了。
那么不多说,上代码(C语言完成)
#include <stdio.h>
#include <string.h>
#define getchar getchar_unlocked
#define putchar putchar_unlocked
#define maxd 100
#define mod 1000000007
typedef long long ll;
ll rd()
{
ll k = 0;
char c = getchar();
while (c < '0' || c > '9')
c = getchar();
while (c >= '0' && c <= '9')
{
k = (k << 1) + (k << 3) + (c ^ 48);
c = getchar();
}
return k;
}
void wr(ll x)
{
if (x > 9)
wr(x / 10);
putchar(x % 10 + '0');
}
void swap(int *a, int *b)
{
if ((*a) != (*b))
(*a) ^= (*b), (*b) ^= (*a), (*a) ^= (*b);
}
int _inv(int a)
{
int _a = a, _b = mod, u = 1, v = 0;
while (_b)
{
int t = _a / _b;
_a -= t * _b;
swap(&_a, &_b);
u -= t * v;
swap(&u, &v);
}
u += mod, u %= mod;
return u;
}
int quick_pow(int x, ll p)
{
ll ans = 1;
while (p)
{
if (p & 1)
ans = ans * x % mod;
x = 1ll * x * x % mod;
p >>= 1;
}
return ans;
}
int inv[maxd + 10], cur_max_d;
int s1[maxd + 10], s2[maxd + 10];
void init_all() { inv[1] = 1, cur_max_d = 1; }
void init_inv(int d)
{
if (cur_max_d < d || d == 1)
{
for (int i = cur_max_d + 1; i <= d + 1; ++i)
inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
cur_max_d = d;
}
}
void init_case_1(int d) { memset(s1, 0, sizeof(s1[0]) * (d + 1)); }
void init_case_2(int d) { memset(s1, 0, sizeof(s1[0]) * (d + 1)), memset(s2, 0, sizeof(s2[0]) * (d + 1)); }
// \sum_{i=0}^{n-1} {(r^i)*(i^d)}
int sum_of_exp_time_poly(int r, int d, ll n)
{
init_inv(d);
int ans = 0;
ll mod_n = n, pow_n = n;
// 本题 n 的范围并不超过质数,所以后面这些取模都不需要了
// ll mod_n = n % mod, pow_n = n % (mod - 1);
// int len = strlen(n);
/*
if (len < 20)
{
for (int i = 0; i < len; ++i)
mod_n = (mod_n << 1) + (mod_n << 3) + (n[i] ^ '0');
pow_n = mod_n % (mod - 1), mod_n %= mod;
}
else
{
for (int i = 0; i < len; ++i)
{
mod_n = (mod_n << 1) + (mod_n << 3) + (n[i] ^ '0');
pow_n = (pow_n << 1) + (pow_n << 3) + (n[i] ^ '0');
mod_n = (mod_n >= mod) ? mod_n % mod : mod_n;
pow_n = (pow_n >= (mod - 1)) ? pow_n % (mod - 1) : pow_n;
}
}
*/
if (r == 1)
{
init_case_1(d);
for (int i = 0, t = 0, b = 1; i <= d; ++i)
{
b = 1ll * b * (mod_n + mod - i) % mod * inv[i + 1] % mod;
t = (t + quick_pow(i, d)) % mod;
s1[i] = 1ll * t * b % mod;
}
for (int i = 0, b = 1; i <= d; ++i)
{
ans = (ans + 1ll * b * ((i & 1) ? mod - s1[d - i] : s1[d - i])) % mod;
b = 1ll * b * (mod_n + mod - 1 - (d - i)) % mod * inv[i + 1] % mod;
}
}
else
{
init_case_2(d);
int t1 = 0, t2 = 0;
for (int i = 0, rpow = 1; i <= d; ++i, rpow = 1ll * rpow * r % mod)
{
s1[i] = t1 = (t1 + 1ll * quick_pow(i, d) * rpow) % mod;
s2[i] = t2 = (t2 + 1ll * quick_pow(i + mod_n, d) * rpow) % mod;
}
int ans1 = 0, ans2 = 0, b = 1, mr = mod - r;
for (int i = 0; i <= d; ++i)
{
ans1 = (ans1 + 1ll * b * s1[d - i]) % mod;
ans2 = (ans2 + 1ll * b * s2[d - i]) % mod;
b = 1ll * b * mr % mod * (d + 1 - i) % mod * inv[i + 1] % mod;
}
ans = ans1 + mod - 1ll * quick_pow(r, pow_n) * ans2 % mod;
ans = 1ll * ans * _inv(quick_pow(mod + 1 - r, d + 1)) % mod;
}
return ans;
}
int b[maxd + 10];
int main()
{
init_all();
ll n = rd() + 1;
int m = rd(), a = rd();
for (int i = 0; i <= m; ++i)
b[i] = rd();
ll ans = 0;
for (int i = 0; i <= m; ++i)
{
ans += 1ll * b[i] * sum_of_exp_time_poly(a, i, n) % mod;
if (ans >= mod)
ans %= mod;
}
wr(ans);
}
加点注释吧,还是看不太懂啊