题解 P4721 【【模板】分治 FFT】

· · 题解

我的博客

题目大意:给定长度为 $n-1$ 的数组 $g_{[1,n)}$,求 $f_{[0,n)}$,要求:

$$f_i=\sum_{j=1}^{i}f_{i-j}g_j,\qquad f_0=1$$

题解:直接按定义计算的复杂度是 $O(n^2)$,明显不可以通过此题。

分治 FFT:可以用 CDQ 分治。先递归求出 $f_{[l,mid)}$,可以发现这部分对右半区间 $f_{[mid,r)}$ 的贡献是 $f_{[l,mid)}*g_{[0,r-l)}$,把卷积结果加到对应位置,再递归处理右半区间就行了,复杂度 $O(n\log_2^2 n)$。

有同学指出我这份代码有问题,我回去看了一下,发现读入 $g$ 数组的地方写错了:应该读入 $g_1\sim g_{n-1}$,却写成了 $g_1\sim g_n$;又因为用了快读,越界读入没有报错,所以一直没有发现问题。十分抱歉(感谢 @Lisy_03 指出错误)。

C++ Code:

#include <algorithm>
#include <cstdio>
#include <cctype>
// Minimal buffered stand-ins for std::cin / std::cout that support only the
// operations this program uses (reading ints, writing ints and single chars).
// All of stdin is read while `cin` is constructed, and all output is written
// when `cout` is destroyed.
// NOTE(review): adding declarations to namespace std is formally undefined
// behaviour; it is kept so the existing `std::cin >> x` / `std::cout << x`
// call sites keep compiling. Do not include <iostream> next to this.
namespace std {
    struct istream {
#define M (1 << 21 | 3)
        // `ch` points at the next unread byte of `buf`.
        char buf[M], *ch = buf;
        inline istream() {
#ifndef ONLINE_JUDGE
            freopen("input.txt", "r", stdin);
#endif
            // Read at most M - 1 bytes: `cin` has static storage, so `buf` starts
            // zero-filled and its last byte then always stays a '\0' sentinel.
            fread(buf, 1, M - 1, stdin);
        }
        // Parses one decimal integer (with optional leading '-') into x.
        // Leading whitespace is skipped, and the single delimiter byte that
        // follows the digits is consumed unless it is the '\0' sentinel.
        inline istream& operator >> (int &x) {
            // <cctype> classifiers require a value representable as unsigned char.
            while (isspace(static_cast<unsigned char>(*ch))) ++ch;
            const bool negative = (*ch == '-');
            if (negative) ++ch;
            for (x = 0; isdigit(static_cast<unsigned char>(*ch)); ++ch) x = x * 10 + (*ch & 15);
            if (negative) x = -x;
            if (*ch) ++ch;
            return *this;
        }
#undef M
    } cin;
    struct ostream {
#define M (1 << 21 | 3)
        // `ch` points at the next free byte of `buf`.
        char buf[M], *ch = buf;
        // Scratch counter: number of digits currently held by operator<<(int).
        int w;
        // Appends the decimal representation of x (including a '-' sign).
        inline ostream& operator << (int x) {
            // Work on the magnitude as unsigned so INT_MIN is handled and no
            // power-of-ten temporary can overflow.
            unsigned int magnitude = static_cast<unsigned int>(x);
            if (x < 0) {
                *ch++ = '-';
                magnitude = 0u - magnitude;
            }
            char digits[10];  // a 32-bit magnitude has at most 10 decimal digits
            w = 0;
            do {
                digits[w++] = static_cast<char>('0' + magnitude % 10);
                magnitude /= 10;
            } while (magnitude);
            while (w) *ch++ = digits[--w];
            return *this;
        }
        // Appends a single character.
        inline ostream& operator << (const char x) {*ch++ = x; return *this;}
        // Flushes everything buffered so far to stdout.
        inline ~ostream() {
#ifndef ONLINE_JUDGE
            freopen("output.txt", "w", stdout);
#endif
            fwrite(buf, 1, ch - buf, stdout);
        }
#undef M
    } cout;
}

// Array capacity: 2^17 plus three spare slots. Parenthesised so the macro
// behaves as a single value inside larger expressions such as `2 * maxn`.
#define maxn (131072 | 3)
// mod: the prime modulus for all arithmetic; G: the base from which
// Poly::init derives the roots of unity as G^((mod - 1) / lim).
const int mod = 998244353, G = 3;

namespace Math {
    // Returns base^p modulo `mod` using binary exponentiation.
    // p must be non-negative (a negative p would never reach zero by shifting).
    inline int pw(int base, int p) {
        int result = 1;
        while (p) {
            if (p & 1) result = static_cast<long long>(result) * base % mod;
            base = static_cast<long long>(base) * base % mod;
            p >>= 1;
        }
        return result;
    }
    // Modular inverse of x, computed as x^(mod - 2) (Fermat's little theorem,
    // valid because mod is prime). x must not be a multiple of mod.
    inline int inv(int x) {return pw(x, mod - 2);}
}

// n: number of terms of f to output (read in main).
int n;
// f: the sequence being computed; main seeds f[0] = 1 and prints f[0..n-1].
int f[maxn];
// g: the input sequence; main fills g[1..n-1] and g[0] stays 0.
int g[maxn];
// Number-theoretic transform helpers plus the CDQ divide-and-conquer driver
// that fills the global array f from the global array g.
namespace Poly {
// Parenthesised so that expressions such as `N + 1` evaluate as intended.
#define N (131072 | 3)
    // lim  : current transform length (a power of two), set by init()
    // s    : log2(lim) - 1, the shift used to build the bit-reversal table
    // ilim : modular inverse of lim, applied after an inverse transform
    // rev  : bit-reversal permutation of [0, lim)
    int s, lim, ilim, rev[N];
    // Wn[k] = w^k for k in [0, lim], where w = G^((mod - 1) / lim).
    int Wn[N + 1];

    // Brings a value from [-mod, mod) back into [0, mod).
    inline void reduce(int &x) {
        if (x < 0) x += mod;
    }

    // Zero-fills [l, r); does nothing when l is at or beyond r.
    inline void clear(int *l, const int *r) {
        for (; l < r; ++l) *l = 0;
    }

    // Prepares lim (the smallest power of two strictly greater than n), the
    // bit-reversal table and the table of powers of the lim-th root of unity.
    inline void init(const int n) {
        s = -1, lim = 1;
        while (lim <= n) lim <<= 1, s++;
        ilim = Math::inv(lim);
        rev[0] = 0;
        for (int i = 1; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
        const int t = Math::pw(G, (mod - 1) / lim);
        Wn[0] = 1;
        for (int i = 0; i < lim; i++) Wn[i + 1] = static_cast<long long> (Wn[i]) * t % mod;
    }

    // In-place transform of A[0, lim). op != 0 is the forward transform;
    // op == 0 is the inverse, which walks the root table backwards and then
    // scales every entry by ilim.
    inline void NTT(int *A, const int op = 1) {
        for (int i = 1; i < lim; i++) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        for (int mid = 1; mid < lim; mid <<= 1) {
            const int t = lim / mid >> 1;  // stride in Wn for this butterfly width
            for (int i = 0; i < lim; i += mid << 1) {
                for (int j = 0; j < mid; j++) {
                    const int W = op ? Wn[t * j] : Wn[lim - t * j];
                    const int X = A[i + j], Y = static_cast<long long> (A[i + j + mid]) * W % mod;
                    // X + Y and X - Y, each kept in [0, mod) without a division.
                    reduce(A[i + j] += Y - mod);
                    reduce(A[i + j + mid] = X - Y);
                }
            }
        }
        if (!op) for (int i = 0; i < lim; i++) A[i] = static_cast<long long> (A[i]) * ilim % mod;
    }

    // Scratch buffers for the two operands of each convolution.
    int A[N], B[N];
    // Finalises f[l, r). After the left half f[l, mid) is final, its
    // contribution to the right half is the product f[l, mid) * g[0, r - l);
    // entries [mid - l, r - l) of that product are added to f[mid, r) before
    // recursing into the right half.
    void CDQ_NTT(const int l, const int r) {
        if (r - l < 2) return ;
        const int mid = (l + r) >> 1;
        CDQ_NTT(l, mid);
        init(r - l);
        std::copy(f + l, f + mid, A); clear(A + mid - l, A + lim);
        std::copy(g, g + r - l, B); clear(B + r - l, B + lim);
        NTT(A), NTT(B);
        for (int i = 0; i < lim; i++) A[i] = static_cast<long long> (A[i]) * B[i] % mod;
        NTT(A, 0);
        for (int i = mid; i < r; i++) reduce(f[i] += A[i - l] - mod);
        CDQ_NTT(mid, r);
    }
#undef N
}

// Reads n followed by g[1..n-1], seeds f[0] = 1, runs the divide-and-conquer
// convolution over [0, n) and prints f[0..n-1], each followed by a space.
int main() {
    std::cin >> n;
    int idx = 1;
    while (idx < n) {
        std::cin >> g[idx];
        ++idx;
    }
    f[0] = 1;
    Poly::CDQ_NTT(0, n);
    idx = 0;
    while (idx < n) {
        std::cout << f[idx] << ' ';
        ++idx;
    }
    std::cout << '\n';
    return 0;
}

既然我代码锅了,要重新审核,就再补一篇求逆的博客吧。(然后我求逆的博客又锅了,这是后话,感谢 @Owen_codeisking 指出错误)

我的博客

题解:发现这道题就是求 $f*g=f-1$($f-1$ 就是没有常数项的 $f$),改写一下式子:

$$\begin{aligned}f*g&\equiv f-1\pmod{x^n}\\f-f*g&\equiv1\pmod{x^n}\\f*(1-g)&\equiv1\pmod{x^n}\\f&\equiv(1-g)^{-1}\pmod{x^n}\end{aligned}$$

就可以求逆了

C++ Code:

#include <algorithm>
#include <cstdio>
#include <cctype>
// Minimal buffered stand-ins for std::cin / std::cout that support only the
// operations this program uses (reading ints, writing ints and single chars).
// All of stdin is read while `cin` is constructed, and all output is written
// when `cout` is destroyed.
// NOTE(review): adding declarations to namespace std is formally undefined
// behaviour; it is kept so the existing `std::cin >> x` / `std::cout << x`
// call sites keep compiling. Do not include <iostream> next to this.
namespace std {
    struct istream {
#define M (1 << 21 | 3)
        // `ch` points at the next unread byte of `buf`.
        char buf[M], *ch = buf;
        inline istream() {
#ifndef ONLINE_JUDGE
            freopen("input.txt", "r", stdin);
#endif
            // Read at most M - 1 bytes: `cin` has static storage, so `buf` starts
            // zero-filled and its last byte then always stays a '\0' sentinel.
            fread(buf, 1, M - 1, stdin);
        }
        // Parses one decimal integer (with optional leading '-') into x.
        // Leading whitespace is skipped, and the single delimiter byte that
        // follows the digits is consumed unless it is the '\0' sentinel.
        inline istream& operator >> (int &x) {
            // <cctype> classifiers require a value representable as unsigned char.
            while (isspace(static_cast<unsigned char>(*ch))) ++ch;
            const bool negative = (*ch == '-');
            if (negative) ++ch;
            for (x = 0; isdigit(static_cast<unsigned char>(*ch)); ++ch) x = x * 10 + (*ch & 15);
            if (negative) x = -x;
            if (*ch) ++ch;
            return *this;
        }
#undef M
    } cin;
    struct ostream {
#define M (1 << 21 | 3)
        // `ch` points at the next free byte of `buf`.
        char buf[M], *ch = buf;
        // Scratch counter: number of digits currently held by operator<<(int).
        int w;
        // Appends the decimal representation of x (including a '-' sign).
        inline ostream& operator << (int x) {
            // Work on the magnitude as unsigned so INT_MIN is handled and no
            // power-of-ten temporary can overflow.
            unsigned int magnitude = static_cast<unsigned int>(x);
            if (x < 0) {
                *ch++ = '-';
                magnitude = 0u - magnitude;
            }
            char digits[10];  // a 32-bit magnitude has at most 10 decimal digits
            w = 0;
            do {
                digits[w++] = static_cast<char>('0' + magnitude % 10);
                magnitude /= 10;
            } while (magnitude);
            while (w) *ch++ = digits[--w];
            return *this;
        }
        // Appends a single character.
        inline ostream& operator << (const char x) {*ch++ = x; return *this;}
        // Flushes everything buffered so far to stdout.
        inline ~ostream() {
#ifndef ONLINE_JUDGE
            freopen("output.txt", "w", stdout);
#endif
            fwrite(buf, 1, ch - buf, stdout);
        }
#undef M
    } cout;
}

// mod: the prime modulus for all arithmetic; G: the base from which
// Poly::init derives the roots of unity as G^((mod - 1) / lim).
const int mod = 998244353, G = 3;
namespace Math {
    // Binary exponentiation: returns base^p modulo `mod`.
    inline int pw(int base, int p) {
        long long acc = 1;
        long long cur = base;
        for (; p != 0; p >>= 1) {
            if (p & 1) acc = acc * cur % mod;
            cur = cur * cur % mod;
        }
        return static_cast<int>(acc);
    }
    // Modular inverse of x, computed as x^(mod - 2).
    inline int inv(int x) {
        return pw(x, mod - 2);
    }
}

// Adds mod to a when a is negative; non-negative values are left unchanged.
inline void reduce(int &a) {
    if (a < 0) a += mod;
}
// 64-bit counterpart that returns the adjusted value instead of updating in place.
inline long long get_reducell(long long a) {
    return a < 0 ? a + mod : a;
}
// Number-theoretic transform helpers plus a Newton-iteration power-series
// inverse.
namespace Poly {
#define N (262144 | 3)
    // lim  : current transform length (a power of two), set by init()
    // ilim : modular inverse of lim, applied after an inverse transform
    // s    : log2(lim) - 1, the shift used to build the bit-reversal table
    // rev  : bit-reversal permutation of [0, lim)
    int lim, ilim, s, rev[N];
    // Wn[k] = w^k for k in [0, lim], where w = G^((mod - 1) / lim).
    int Wn[N + 1];

    // Prepares lim (the smallest power of two strictly greater than n), the
    // bit-reversal table and the table of powers of the lim-th root of unity.
    inline void init(int n) {
        s = -1, lim = 1;
        while (lim <= n) lim <<= 1, s++;
        ilim = Math::inv(lim);
        rev[0] = 0;
        for (int i = 1; i < lim; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
        const int t = Math::pw(G, (mod - 1) / lim);
        Wn[0] = 1;
        for (int i = 0; i < lim; i++) Wn[i + 1] = static_cast<long long> (Wn[i]) * t % mod;
    }

    // Zero-fills [l, r); does nothing when l is at or beyond r.
    inline void clear(int *l, const int *r) {
        for (; l < r; ++l) *l = 0;
    }

    // In-place transform of A[0, lim). op != 0 is the forward transform;
    // op == 0 is the inverse, which walks the root table backwards and then
    // scales every entry by ilim.
    inline void NTT(int *A, const int op = 1) {
        for (int i = 1; i < lim; i++) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        for (int mid = 1; mid < lim; mid <<= 1) {
            const int t = lim / mid >> 1;  // stride in Wn for this butterfly width
            for (int i = 0; i < lim; i += mid << 1) {
                for (int j = 0; j < mid; j++) {
                    const int W = op ? Wn[t * j] : Wn[lim - t * j];
                    const int X = A[i + j], Y = static_cast<long long> (A[i + j + mid]) * W % mod;
                    // X + Y and X - Y, each folded back without a division.
                    reduce(A[i + j] += Y - mod);
                    reduce(A[i + j + mid] = X - Y);
                }
            }
        }
        if (!op) for (int i = 0; i < lim; i++) A[i] = static_cast<long long> (A[i]) * ilim % mod;
    }

    // Scratch copy of the first n coefficients of the polynomial being inverted.
    int C[N];
    // Writes into B[0, n) the inverse of A modulo x^n. It first solves the
    // problem modulo x^ceil(n/2), then applies the step B <- B * (2 - A * B).
    // Both transforms run over lim entries (lim > 2n - 1), so B must have at
    // least that many entries, and everything past the coefficients already
    // computed must be zero on entry.
    void INV(int *A, int *B, int n) {
        if (n <= 0) return ;  // nothing to compute; also stops endless recursion
        if (n == 1) {*B = Math::inv(*A); return ;}
        INV(A, B, (n + 1) >> 1);
        init(n + n - 1);
        std::copy(A, A + n, C); clear(C + n, C + lim);
        NTT(B), NTT(C);
        for (int i = 0; i < lim; i++) B[i] = (2 - static_cast<long long> (B[i]) * C[i] % mod + mod) % mod * B[i] % mod;
        // Drop the coefficients at and above x^n so the next level starts clean.
        NTT(B, 0), clear(B + n, B + lim);
    }
#undef N
}

// Array capacity: 2^18 plus three spare slots.
#define maxn (262144 | 3)
// n: number of terms to output (read in main).
int n;
// f: receives the inverse series computed by Poly::INV and printed by main.
int f[maxn];
// g: holds the series 1 - g that main builds from the input.
int g[maxn];

// Reads n and g[1..n-1], builds the series 1 - g in place, inverts it modulo
// x^n and prints the n resulting coefficients, each followed by a space.
int main() {
    std::cin >> n;
    g[0] = 1;
    for (int i = 1; i < n; i++) {
        std::cin >> g[i];
        // Negate modulo mod, keeping the value in [0, mod): a zero coefficient
        // must stay 0 rather than become the non-canonical value mod.
        g[i] = g[i] ? mod - g[i] : 0;
    }
    Poly::INV(g, f, n);
    for (int i = 0; i < n; i++) std::cout << f[i] << ' ';
    std::cout << '\n';
    return 0;
}