深入了解分治FFT——DP优化的黑科技

· · 算法·理论

问题提出

算法应用于问题,分治 FFT 的出现是为了解决这样一个问题:

给定序列 g_{1\dots n - 1},求序列 f_{0\dots n - 1}

其中 f_i=\sum_{j=1}^if_{i-j}g_j,边界为 f_0=1

具体可以见 【模板】分治 FFT - 洛谷

对于这个问题我们要求做到 \Theta(n\log^2n) 的时间复杂度。

问题解决

理论分析

对于这个问题,我们使用 CDQ 分治的思想。

我们令 solve(l,r) 为计算出将 g[l,r) 内的一段的子问题的函数。

为了方便,我们令最开始一定是调用 solve(0,1 << k) 的形式。

注意这里算的不是 f_i 而是这一段内自己对自己的贡献。

此时,我们需要进行分治。

首先,用 mid 将其分为 [l,mid)[mid,r) 两段。

为了方便我们记 len=r-l

对于左边一段可以直接递归向下计算。

然而对于右边一段,还要考虑左侧向右侧的贡献,因为当 j 走遍 r-l 的时候,i-j 可能跑到左侧的一段里来。

我们现在就思考如何计算这个问题的贡献。

我们将左侧区间对右侧区间中的 f_i 的贡献 x_i 写出来:

x_i=\sum_{j=\frac{len}{2}+1}^{len} f_{i-j}g_j

我们将 f[l,mid) 内的数组拎出来记为 h_i=f_{i-l}

x[mid,r) 内的数组拎出来记为 y_i=x_{i-mid}

所以:

y_i=\sum_{j=\frac{len}{2}+1}^{len} h_{i-j+\frac{len}{2}}g_j

可以写为:

y_i=\sum_{p+q=i+\frac{len}{2}} h_pg_q

发现这是一个和卷积!

我们令 F(x)=\sum y_ix^i,G(x)=\sum g_ix^i,H(x)=\sum h_ix^i

于是:

F(x)=G(x)H(x)

多项式乘法,使用 FFT 进行计算即可。

到这里我们就完成了在 \Theta(n\log^2n) 的时间复杂度内计算该问题的方法。

示例代码

/*
 * @Author: LightningCreeper [email protected]
 * @Date: 2024-12-30 18:30:22
 * @LastEditors: LightningCreeper [email protected]
 * @LastEditTime: 2024-12-30 20:37:17
 * @FilePath: /i.省队训练/CDQFFT.cpp
 * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
 */
#include <bits/stdc++.h>
using namespace std;

#define int long long
#define endl '\n'
#define debug(x) cerr << #x << " = " << x << endl
#define gn(u, v) for (int v : G.G[u])
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define pii pair<int, int>
#define vi vector<int>
#define vpii vector<pii>
#define vvi vector<vi>
#define no cout << "NO" << endl
#define yes cout << "YES" << endl
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define tomin(x, y) ((x) = min((x), (y)))
#define tomax(x, y) ((x) = max((x), (y)))
#define ck(mask, i) (((mask) >> (i)) & 1)
#define pq priority_queue
#define FLG (cerr << "Alive!" << endl);

constexpr int MAXN = (1 << 20) + 5;
constexpr int MOD = 998244353;
constexpr int G = 3;
constexpr int GINV = 332748118;

int qpow(int x, int y) {
    int ans = 1;
    while (y) {
        if (y & 1)
            ans = ans * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return ans;
}

int rt[MAXN];

void NTT(int x[], int len, bool t) {
    for (int i = 0; i < len; i++)
        if (i < rt[i])
            swap(x[i], x[rt[i]]);
    for (int i = 1; i < len; i *= 2) {
        int n;
        if (t)
            n = G;
        else
            n = GINV;
        n = qpow(n, (MOD - 1) / (i * 2));

        for (int j = 0; j < len; j += (i * 2)) {
            for (int k = 0, t = 1; k < i; k++, t = (t * n) % MOD) {
                int p = x[j + k];
                int q = t * x[j + k + i] % MOD;
                x[j + k] = (p + q);
                x[j + k + i] = p - q + MOD;
                if (x[j + k] >= MOD)
                    x[j + k] -= MOD;
                if (x[j + k + i] >= MOD)
                    x[j + k + i] -= MOD;
            }
        }
    }
}

void Times(int a[], int b[], int len) {
    int l = __lg(len);
    for (int i = 0; i < len; i++)
        rt[i] = 0;
    for (int i = 0; i < len; i++)
        rt[i] = (rt[i >> 1] >> 1) | ((i & 1) << (l - 1));

    NTT(a, len, true);
    NTT(b, len, true);
    for (int i = 0; i < len; i++)
        a[i] = a[i] * b[i] % MOD;
    NTT(a, len, false);
    NTT(b, len, false);
    int inv = qpow(len, MOD - 2);

    for (int i = 0; i < len; i++) {
        a[i] = a[i] * inv % MOD;
        b[i] = b[i] * inv % MOD;
    }
}

int n;
int tmp[MAXN], tmp2[MAXN], f[MAXN], g[MAXN];

void CDQ(int l, int r) {
    if (l + 1 == r)
        return;

    int mid = l + r >> 1;
    int len = r - l;
    CDQ(l, mid);

    for (int i = 0; i < len * 2; i++)
        tmp[i] = 0;
    for (int i = l; i < mid; i++)
        tmp[i - l] = f[i];
    Times(tmp, g, len * 2);

    for (int i = mid; i < r; i++)
        (f[i] += tmp[i - l]) %= MOD;
    CDQ(mid, r);
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    for (int i = 1; i < n; i++)
        cin >> g[i];
    int pastn = n;
    while (n != (1 << __lg(n)))
        n++;

    f[0] = 1;
    CDQ(0, n);
    for (int i = 0; i < pastn; i++)
        cout << f[i] << " ";
    cout << endl;

    return 0;
}

深入探究

我们想想这玩意儿有什么用阿。

换句话来说,我们什么时候会用到这种形状的式子呢?

DP

在很多的“分段DP”中,如果长度相同的一段的贡献相同,整个式子就会长的跟个分治 FFT 一样。

这点应用范围较广,在我们提前算出了一段长度的贡献后,可以使用分治 FFT 进行实现,两只 \log 的时间复杂度大部分情况下是够用的。

然而他有个缺点,常数大,怎么解决?

  1. 将长度延伸到 2 的整数次幂,既好写,又减少了分讨情况,优化了常数。

  2. 若有多个形状类似的转移则在一个 CDQ 里同时计算,减去了递归的大常数问题。

  3. 每次计算多项式乘法时不用将两个数组都延伸到很长,由于一个长度是另一个的两倍,且由于只求后两倍长度的贡献,可以直接把前面的一倍长度砍掉,减少 FFT 时的时间开销。

  4. 每一次将两个数组都复制一遍进行 FFT 这样就不用将另一个数组还原回去,原本要进行四次变换就变成了三次。

将所有优化加起来可以优化掉将近十倍常数,跑的飞快,可以轻松过掉 3\times10^5 的数据。

但我们还是不知足,我们还要发挥他的最大作用。

首先我们想想既然分治 FFT 是基于 CDQ 的,能不能也像 CDQ 一样嵌套多层呢?

很遗憾的是,朕做不到。

究其原因是因为分治 FFT 中的一个值依赖于前面的值,若嵌套多维则需要做到与前面的值剥离关系。

那他就真的没用了吗?

回顾整个分治 FFT 的过程,他总是按照下标顺序依次计算出答案。

这给我们什么启发?

对于一个在线问题,依次询问每个位子的值 ,就是分治 FFT 在做的事情。

但是似乎我们可以加入一些修改?

假设我们依次询问值的同时,允许对值进行一些修改?

这样会影响到后续的值,确实是有效的。

经过思考之后发现,这做得到!

我们考虑在每次递归下去后每次看有没有修改,如果有,就在对应的一条递归链上改就行了!

单次修改 \log^2n !(因为要计算对后的贡献。

我们得寸进尺地想想能不能改转移的系数。

不行。

因为转移的系数会影响整个答案,光光一个一个改都爆 \Theta(n) 了。

所以说我们的分治 FFT 进化了!

这个时候我们发现他的式子形式像卷积。

我们便称他为半在线卷积

因为它可以支持小幅度修改。

练习题

CF848E Days of Floral Colours 题解 - LightningCreeper - 博客园