题目描述
给定一个 $n-1$ 次多项式 $A(x)$,求一个在 ${} \bmod x^n$ 意义下的多项式 $B(x)$,使得 $B^2(x) \equiv A(x) \pmod{x^n}$。若有多解,请取零次项系数较小的作为答案。
多项式的系数在 ${}\bmod 998244353$ 的意义下进行运算。
过程
我们假设 $n$ 为 $2$ 的正整数次幂,不是的话用 $0$ 补足。
然后递归求解,设我们现在已经得到了多项式 $B’(x)$,满足 $B’^2(x) \equiv A(x) \pmod{x^{\frac{z}{2}}}$。
然后我们需要的是一个多项式 $B(x)$,满足 $B^2(x) \equiv A(x) \pmod{x^z}$。
$$B(x) - B’(x) \equiv 0 \pmod{x^{\frac{z}{2}}}$$
$$(B(x) - B’(x))^2 \equiv 0 \pmod{x^z}$$
$$B^2(x) - 2 B(x) B’(x) + B’^2(x) \equiv 0\pmod{x^z}$$
$$B(x) \equiv \frac{A(x) + B’^2(x)}{2B’(x)} \pmod{x^z}$$
于是求个多项式求逆即可,初始化用二次剩余,时间复杂度 $O(n\log n)$。
代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1000010, M = 25, g = 3, gi = 332748118, inv2 = 499122177, mod = 998244353;
ll a[N];
ll I_mul_I;
struct Complex {
ll a, b;
Complex operator* (Complex y) {
return {(a * y.a % mod + b * y.b % mod * I_mul_I % mod) % mod, (a * y.b % mod + b * y.a % mod) % mod};
}
};
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
Complex qpow(Complex x, ll y) {
Complex res = {1, 0};
while (y) {
if (y & 1) res = res * x;
x = x * x;
y >>= 1;
}
return res;
}
bool legendre(ll x) {
return qpow(x, (mod - 1) >> 1) == 1;
}
ll cipolla(ll n) {
n %= mod;
if (!n) return 0;
if (!legendre(n)) return -1;
ll a = rand() % mod;
while (legendre((a * a + mod - n) % mod)) a = rand() % mod;
I_mul_I = (a * a + mod - n) % mod;
return qpow(Complex({a, 1}), (mod + 1) >> 1).a;
}
int rev[1 << M];
void init(int n) {
int lgn = (int)log2(n);
for (int i = 0; i < n; i ++ ) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lgn - 1));
}
void ntt(ll *a, int n, int Inv = 0) {
for (int i = 0; i < n; i ++ ) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1) {
ll gn = qpow(Inv ? gi : g, (mod - 1) / (i << 1));
for (int j = 0; j < n; j += (i << 1)) {
ll gk = 1;
for (int k = 0; k < i; k ++ , gk = gk * gn % mod) {
ll x = a[j + k], y = gk * a[i + j + k] % mod;
a[j + k] = (x + y) % mod, a[i + j + k] = (x - y + mod) % mod;
}
}
} if (Inv) {
ll inv = qpow(n, mod - 2);
for (int i = 0; i < n; i ++ ) a[i] = a[i] * inv % mod;
}
}
ll b[N], t[N];
void Inv(ll *a, int n) {
memset(t, 0, sizeof t);
t[0] = qpow(a[0], mod - 2);
int k = 1;
do {
k <<= 1; init(k << 1);
for (int i = 0; i < k; i ++ ) b[i] = a[i];
for (int i = k; i < (k << 1); i ++ ) b[i] = 0;
ntt(b, k << 1); ntt(t, k << 1);
for (int i = 0; i < (k << 1); i ++ ) t[i] = (t[i] * 2 - t[i] * t[i] % mod * b[i] % mod + mod) % mod;
ntt(t, k << 1, 1);
for (int i = k; i < (k << 1); i ++ ) t[i] = 0;
} while (k < n);
for (int i = 0; i < n; i ++ ) a[i] = t[i];
}
ll t2[N], b2[N], c2[N];
void Sqrt(ll *a, int n) {
memset(t2, 0, sizeof t2);
ll x = cipolla(a[0]);
t2[0] = min(x, mod - x);
int k = 1;
do {
k <<= 1;
for (int i = 0; i < k; i ++ ) b2[i] = t2[i], c2[i] = a[i];
for (int i = k; i < (k << 1); i ++ ) b2[i] = c2[i] = 0;
Inv(b2, k); init(k << 1);
ntt(b2, k << 1), ntt(c2, k << 1), ntt(t2, k << 1);
for (int i = 0; i < (k << 1); i ++ ) t2[i] = (b2[i] * c2[i] % mod + t2[i]) % mod * inv2 % mod;
ntt(t2, k << 1, 1);
for (int i = k; i < (k << 1); i ++ ) t2[i] = 0;
} while (k < n);
for (int i = 0; i < n; i ++ ) a[i] = t2[i];
}
int main() {
int n;
scanf("%d", &n);
for (int i = 0; i < n; i ++ ) scanf("%lld", &a[i]);
Sqrt(a, n);
for (int i = 0; i < n; i ++ ) printf("%lld ", a[i]);
return 0;
}