ikka
2019-03-30 16:43:31
Update
两处公式挂了换成了图片。
模板在P5205 【模板】多项式开根 。
这里不多做讨论。
求解ntt
的模是奇素数所以这里只讨论
当方程有解时称
我们定义勒让德符号:
有一个欧拉判别准则:
口胡证明:
接下来让我们证明三个结论。
结论一:
证明展开式子中的
结论二:有
根据结论一,可以得出
结论三:
二项式定理展开,由于
第一步,判断原方程是否有解(利用欧拉判别准则)。
第二步,随机找到一个
第三步,找到一个解
证明:
#include <cstdio>
#include <cstring>
#include <ctime>
#include <cstdlib>
#include <algorithm>
const int maxn = 400010;
const int mod = 998244353;
const int g = 3;
const int invg = 332748118;
const int inv2 = 499122177;
int inline pls(int a, int b) { int m = a + b; return m < mod ? m : m - mod; }
int inline dec(int a, int b) { int m = a - b; return m + ((m >> 31) & mod); }
int inline mul(int a, int b) { return 1ll * a * b % mod; }
int inline pow(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = mul(ans, a);
a = mul(a, a);
b >>= 1;
}
return ans;
}
struct cp {
int r, i;
cp(int x = 0, int y = 0) : r(x), i(y) {}
};
cp inline mul(cp a, cp b, int w) {
return cp(pls(mul(a.r, b.r), mul(w, mul(a.i, b.i))), pls(mul(a.r, b.i), mul(a.i, b.r)));
}
int inline pow(cp a, int b, int w) {
cp ans(1, 0);
while (b) {
if (b & 1) ans = mul(ans, a, w);
a = mul(a, a, w);
b >>= 1;
}
return ans.r;
}
int inline cipolla(int x) {
srand(time(0));
if (pow(x, (mod - 1) >> 1) == mod - 1) return -1;
while (true) {
int a = mul(rand(), rand());
int w = dec(mul(a, a), x);
if (pow(w, (mod - 1) >> 1) == mod - 1) {
return pow(cp(a, 1), (mod + 1) >> 1, w);
}
}
}
int a[maxn], b[maxn], r[maxn];
void inline ntt(int *a, int n, int f) {
for (int i = 0; i < n; ++i) if (i < r[i]) std::swap(a[i], a[r[i]]);
for (int i = 1; i < n; i <<= 1) {
int wn = pow(f ? g : invg, (mod - 1) / (i << 1));
for (int *j = a; j < a + n; j += i << 1) {
int w = 1;
for (int k = 0; k < i; ++k, w = mul(w, wn)) {
int x = *(j + k), y = mul(w, *(i + j + k));
*(j + k) = pls(x, y), *(i + j + k) = dec(x, y);
}
}
}
if (!f) {
int rv = pow(n, mod - 2);
for (int *i = a; i < a + n; ++i) *i = mul(*i, rv);
}
}
void inline inv(int *a, int *b, int n) {
b[0] = pow(a[0], mod - 2);
static int A[maxn], B[maxn], len, lim;
for (len = 1; len < n << 1; len <<= 1) {
lim = len << 1;
memcpy(A, a, len << 2); memcpy(B, b, len << 2);
for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0);
ntt(A, lim, 1); ntt(B, lim, 1);
for (int i = 0; i < lim; ++i) b[i] = dec((B[i] << 1) % mod, mul(A[i], mul(B[i], B[i])));
ntt(b, lim, 0);
memset(b + len, 0, len << 2);
}
memset(A, 0, len << 2); memset(B, 0, len << 2);
memset(b + n, 0, (len - n) << 2);
}
void inline sqrt(int *a, int *b, int n) {
int sr = cipolla(a[0]);
b[0] = std::min(sr, mod - sr);
static int A[maxn], B[maxn], len, lim;
for (len = 1; len < n << 1; len <<= 1) {
lim = len << 1;
memcpy(A, a, len << 2);
inv(b, B, len);
for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0);
ntt(A, lim, 1); ntt(B, lim, 1);
for (int i = 0; i < lim; ++i) A[i] = mul(A[i], B[i]);
ntt(A, lim, 0);
for (int i = 0; i < len; ++i) b[i] = mul(inv2, pls(A[i], b[i]));
memset(b + len, 0, len << 2);
}
memset(A, 0, len << 2); memset(B, 0, len << 2);
memset(b + n, 0, (len - n) << 2);
}
int main() {
int n;
scanf("%d", &n);
for (int *i = a; i < a + n; ++i) scanf("%d", i);
sqrt(a, b, n);
for (int *i = b; i < b + n; ++i) printf("%d%c", *i, " \n"[i == b + n - 1]);
return 0;
}
抄我的代码会TLE233