题解 P5277 【【模板】多项式开根(加强版)】

ikka

2019-03-30 16:43:31

题解

Update

两处公式挂了换成了图片。

前置知识

多项式开根

模板在P5205 【模板】多项式开根 。

这里不多做讨论。

Cipolla算法

求解x^2\equiv n \pmod p的算法,由于ntt的模是奇素数所以这里只讨论p为奇素数的情况。

当方程有解时称n在模p意义下是二次剩余,否则称n在模p意义下是非二次剩余。

我们定义勒让德符号

有一个欧拉判别准则

(\frac{n}{p})\equiv n^{\frac{p-1}{2}}\pmod p

口胡证明:

  1. p|n时,显然有0\equiv n^{\frac{p-1}{2}}\pmod p
  2. 假设x^2\equiv n \pmod p,若x^{p-1}\equiv 1\pmod p时,根据费马小定理,假设成立,n在模p意义下是二次剩余。 若x^{p-1}\equiv -1\pmod p时,根据费马小定理,假设不成立,n在模p意义下是二次非剩余。

接下来让我们证明三个结论。

结论一:n^2 \equiv (p-n)^2 \pmod p

证明展开式子中的(p-n)^2即可。

结论二:有\frac{p-1}{2}个数是模p意义下的二次非剩余。

根据结论一,可以得出0p-1的平方中有\frac{p+1}{2}个不同的数,剩下的数即为模p意义下的二次非剩余。

结论三:(a+b)^p\equiv a^p+b^p \pmod p

二项式定理展开,由于p是素数,所以除了首尾两项其他项中的组合数中都有质因子p,故同余式成立。

Cipolla算法流程

第一步,判断原方程是否有解(利用欧拉判别准则)。

第二步,随机找到一个a使得\omega\equiv(a^2-n)\pmod p在模p意义下是二次非剩余。(根据结论二,期望随机次数为2)

第三步,找到一个解x\equiv(a+\sqrt{\omega})^{\frac{p+1}{2}} \pmod p。但是\sqrt{\omega}显然不存在,我们可以把\sqrt{\omega}当做虚数i来看,就把问题解决了。

证明:

Code

#include <cstdio>
#include <cstring>
#include <ctime>
#include <cstdlib>
#include <algorithm>
const int maxn = 400010;
const int mod = 998244353;
const int g = 3;
const int invg = 332748118;
const int inv2 = 499122177;

int inline pls(int a, int b) { int m = a + b; return m < mod ? m : m - mod; }
int inline dec(int a, int b) { int m = a - b; return m + ((m >> 31) & mod); }
int inline mul(int a, int b) { return 1ll * a * b % mod; }
int inline pow(int a, int b) {
  int ans = 1;
  while (b) {
    if (b & 1) ans = mul(ans, a);
    a = mul(a, a);
    b >>= 1;
  }
  return ans;
}

struct cp {
  int r, i;
  cp(int x = 0, int y = 0) : r(x), i(y) {}
};

cp inline mul(cp a, cp b, int w) {
  return cp(pls(mul(a.r, b.r), mul(w, mul(a.i, b.i))), pls(mul(a.r, b.i), mul(a.i, b.r)));
}

int inline pow(cp a, int b, int w) {
  cp ans(1, 0);
  while (b) {
    if (b & 1) ans = mul(ans, a, w);
    a = mul(a, a, w);
    b >>= 1;
  }
  return ans.r;
}

int inline cipolla(int x) {
  srand(time(0));
  if (pow(x, (mod - 1) >> 1) == mod - 1) return -1;
  while (true) {
    int a = mul(rand(), rand());
    int w = dec(mul(a, a), x);
    if (pow(w, (mod - 1) >> 1) == mod - 1) {
      return pow(cp(a, 1), (mod + 1) >> 1, w);
    }
  }
}

int a[maxn], b[maxn], r[maxn];

void inline ntt(int *a, int n, int f) {
  for (int i = 0; i < n; ++i) if (i < r[i]) std::swap(a[i], a[r[i]]);
  for (int i = 1; i < n; i <<= 1) {
    int wn = pow(f ? g : invg, (mod - 1) / (i << 1));
    for (int *j = a; j < a + n; j += i << 1) {
      int w = 1;
      for (int k = 0; k < i; ++k, w = mul(w, wn)) {
        int x = *(j + k), y = mul(w, *(i + j + k));
        *(j + k) = pls(x, y), *(i + j + k) = dec(x, y);
      }
    }
  }
  if (!f) {
    int rv = pow(n, mod - 2);
    for (int *i = a; i < a + n; ++i) *i = mul(*i, rv);
  }
}

void inline inv(int *a, int *b, int n) {
  b[0] = pow(a[0], mod - 2);
  static int A[maxn], B[maxn], len, lim;
  for (len = 1; len < n << 1; len <<= 1) {
    lim = len << 1;
    memcpy(A, a, len << 2); memcpy(B, b, len << 2);
    for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0);
    ntt(A, lim, 1); ntt(B, lim, 1);
    for (int i = 0; i < lim; ++i) b[i] = dec((B[i] << 1) % mod, mul(A[i], mul(B[i], B[i])));
    ntt(b, lim, 0);
    memset(b + len, 0, len << 2);
  }
  memset(A, 0, len << 2); memset(B, 0, len << 2);
  memset(b + n, 0, (len - n) << 2);
}

void inline sqrt(int *a, int *b, int n) {
  int sr = cipolla(a[0]);
  b[0] = std::min(sr, mod - sr);
  static int A[maxn], B[maxn], len, lim;
  for (len = 1; len < n << 1; len <<= 1) {
    lim = len << 1;
    memcpy(A, a, len << 2);
    inv(b, B, len);
    for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0);
    ntt(A, lim, 1); ntt(B, lim, 1);
    for (int i = 0; i < lim; ++i) A[i] = mul(A[i], B[i]);
    ntt(A, lim, 0);
    for (int i = 0; i < len; ++i) b[i] = mul(inv2, pls(A[i], b[i]));
    memset(b + len, 0, len << 2);
  }
  memset(A, 0, len << 2); memset(B, 0, len << 2);
  memset(b + n, 0, (len - n) << 2);
}

int main() {
  int n;
  scanf("%d", &n);
  for (int *i = a; i < a + n; ++i) scanf("%d", i);
  sqrt(a, b, n);
  for (int *i = b; i < b + n; ++i) printf("%d%c", *i, " \n"[i == b + n - 1]);
  return 0;
}

抄我的代码会TLE233