一类用转置原理优化算法的 Trick
更好的阅读体验
注:本文前半部分大致描述了转置原理,中间为 CF2039F1/F2 的题解,最后一部分是其在 NTT 上的应用。
对于一些算法,如果我们可以将其视作对一个向量的若干线性变换,则称这种算法为线性算法。
具体地,我们可以将算法过程中维护的变量写成一个向量
把线性变换写作一个矩阵
我们称
但是我们显然不能把矩阵
首先把矩阵
这样就相当于把所有初等变换的顺序倒过来然后把每个初等变换转置一下,于是我们只需要求每个初等变换的转置就好了:
考虑一个初等矩阵
- 对于初等变换交换矩阵的某两行
i, j :M_{j, i} = M_{i, j} = 1, M_{i, i} = M_{j, j} = 0 ,转置后有M^T_{j, i} = M^T_{i, j} = 1, M^T_{i, i} = M^T_{j, j} = 0 ,因此这种初等变换的转置和原本一样,对应到向量上的操作就是交换两维。 - 对于初等变换给第
i 行乘上k :M_{i, i} = k ,转置后有M^T_{i, i} = k ,因此这种初等变换的转置和原本一样,对应到向量上的操作就是某一维乘上一个系数k 。 - 对于初等变换给第
i 行加上第j 行:M_{i, j} = 1 ,转置后有M^T_{j, i} = 1 ,对应到向量上的操作就是给第i 维加上第j 维变成给第j 维变成第i 维。
于是我们可以得到线性算法所有运算均为以下三种之一:
- 交换变量
x, y ,转置依然是交换变量x, y 。 -
-
显然一个算法转置后复杂度不变,如果我们能对转置后的算法进行优化,那优化后再转置一次就可以得到原算法的优化了。
另一方面,在某些题目中可以将转置后的算法进行一些小修改使其达到与原算法相同的效果。
例题
CF2039F1 Shohag Loves Counting (Easy Version) / CF2039F2 Shohag Loves Counting (Hard Version)
先考虑 F1,考虑什么样的序列是好的,显然每次长度增加
然后考虑区间
因此问题转化成:求所有单调递减序列
设
转移式子大概长这样:
然后发现系数是
考虑优化转移,显然可以改成容斥,令
后面减去的是求
直接做是
参考代码:
// 長い夜の終わりを信じながら
// Think twice, code once.
#include <vector>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define eputchar(c) putc(c, stderr)
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#define eputs(str) fputs(str, stderr), putc('\n', stderr)
using namespace std;
const int mod = 998244353;
int T, n, mu[1000005], dp[1000005], sum[1000005], tmp[1000005];
vector<int> vec[1000005];
int main() {
mu[1] = 1;
for (int i = 1; i <= 1000000; i++)
for (int j = i; j <= 1000000; j += i) {
vec[j].push_back(i);
if (j != i) mu[j] -= mu[i];
}
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++) dp[i] = sum[i] = 0;
for (int i = n; i >= 1; i--) {
for (int j : vec[i]) tmp[j] = (mod - dp[j] * 2 % mod) % mod;
for (int j : vec[i])
for (int k : vec[j]) tmp[k] = (tmp[k] + 2ll * sum[j] * mu[j / k] % mod + mod) % mod;
tmp[i] = (tmp[i] + 1) % mod;
for (int j : vec[i]) {
dp[j] = (dp[j] + tmp[j]) % mod;
for (int k : vec[j]) sum[k] = (sum[k] + tmp[j]) % mod;
}
}
printf("%d\n", sum[1]);
}
return 0;
}
由于我们是从后往前枚举的,因此无法通过 F2。
但是这显然是线性算法,我们可以将其视为有一个向量
初始时
然后枚举的每个
将其转置后即为:
也就是初始时令
由于是从前往后做的,因此后一个的答案可以从前一个继承过来,只需要做一遍就可以求出所有
参考代码:
// 長い夜の終わりを信じながら
// Think twice, code once.
#include <vector>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define eputchar(c) putc(c, stderr)
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#define eputs(str) fputs(str, stderr), putc('\n', stderr)
using namespace std;
const int mod = 998244353;
int T, n, mu[1000005], dp[1000005], sum[1000005], tmp[1000005], val, ans[1000005];
vector<int> vec[1000005];
int main() {
mu[1] = 1;
for (int i = 1; i <= 1000000; i++)
for (int j = i; j <= 1000000; j += i) {
vec[j].push_back(i);
if (j != i) mu[j] -= mu[i];
}
sum[1] = 1;
for (int i = 1; i <= 1000000; i++) {
for (int j : vec[i]) {
tmp[j] = (tmp[j] + dp[j]) % mod;
for (int k : vec[j]) tmp[j] = (tmp[j] + sum[k]) % mod;
}
val = (val + tmp[i]) % mod;
for (int j : vec[i])
for (int k : vec[j]) sum[j] = (sum[j] + 2ll * tmp[k] * mu[j / k] % mod + mod) % mod;
for (int j : vec[i]) dp[j] = (dp[j] - tmp[j] * 2 % mod) % mod;
for (int j : vec[i]) tmp[j] = 0;
ans[i] = val;
}
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
printf("%d\n", ans[n]);
}
return 0;
}
优化 NTT
先贴个码:
static void NTT(poly &g, int flag) {
int n = g.size();
vector<unsigned long long> f(g.begin(), g.end());
vector<int> swp(n);
for (int i = 0; i < n; i++) {
swp[i] = swp[i >> 1] >> 1 | ((i & 1) * (n >> 1));
if (i < swp[i]) std::swap(f[i], f[swp[i]]);
}
for (int mid = 1; mid < n; mid <<= 1) {
int w1 = power(flag ? G : invG, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)w[j] * f[i + mid + j] % mod;
f[i + mid + j] = f[i + j] - t + mod;
f[i + j] += t;
}
if (mid == 1 << 10)
for (int i = 0; i < n; i++) f[i] %= mod;
}
int inv = flag ? 1 : power(n, mod - 2);
for (int i = 0; i < n; i++) g[i] = f[i] % mod * inv % mod;
return;
}
显然 NTT 也是线性算法,考虑它的转置长什么样。
首先翻转顺序,然后考虑中间这部分:
int t = (long long)w[j] * f[i + mid + j] % mod;
f[i + mid + j] = f[i + j] - t + mod;
f[i + j] += t;
以下用
转置后就变成:
于是我们可以写出转置后的 NTT:
static void NTT(poly &g, int flag) {
int n = g.size();
vector<unsigned long long> f(g.begin(), g.end());
int inv = flag ? 1 : power(n, mod - 2);
for (int i = 0; i < n; i++) f[i] = f[i] % mod * inv % mod;
for (int mid = n >> 1; mid >= 1; mid >>= 1) {
int w1 = power(flag ? G : invG, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)(f[i + j] - f[i + mid + j] % mod + mod) * w[j] % mod;
f[i + j] += f[i + mid + j];
f[i + mid + j] = t;
}
if (mid == 1 << 10)
for (int i = 0; i < n; i++) f[i] %= mod;
}
vector<int> swp(n);
for (int i = 0; i < n; i++) {
swp[i] = swp[i >> 1] >> 1 | ((i & 1) * (n >> 1));
if (i < swp[i]) std::swap(f[i], f[swp[i]]);
}
for (int i = 0; i < n; i++) g[i] = f[i] % mod;
return;
}
然后考虑 NTT 的本质是什么,本质是求多项式在
和:
所以 NTT 的转置和 NTT 效果完全相同!
对于很多多项式操作(如多项式乘法、求逆等),在 DFT 之后 IDFT 之前,我们并不关心其每个值的顺序,因此如果只对 DFT 转置,那 DFT 和 IDFT 之间的两次蝴蝶变换就可以被省略掉了!
// 这份代码还稍微优化了下取模常数。
static void NTT(poly &g, int flag) {
int n = g.size();
vector<int> f(g.begin(), g.end());
if (flag) {
for (int mid = n >> 1; mid >= 1; mid >>= 1) {
int w1 = power(G, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)(f[i + j] - f[i + mid + j] + mod) * w[j] % mod;
f[i + j] = f[i + j] + f[i + mid + j] >= mod ?
f[i + j] + f[i + mid + j] - mod : f[i + j] + f[i + mid + j];
f[i + mid + j] = t;
}
}
for (int i = 0; i < n; i++) g[i] = f[i];
} else {
for (int mid = 1; mid < n; mid <<= 1) {
int w1 = power(invG, (mod - 1) / mid / 2);
vector<int> w(mid);
w[0] = 1;
for (int i = 1; i < mid; i++) w[i] = (long long)w[i - 1] * w1 % mod;
for (int i = 0; i < n; i += mid << 1)
for (int j = 0; j < mid; j++) {
int t = (long long)w[j] * f[i + mid + j] % mod;
f[i + mid + j] = f[i + j] - t < 0 ? f[i + j] - t + mod : f[i + j] - t;
f[i + j] = f[i + j] + t >= mod ? f[i + j] + t - mod : f[i + j] + t;
}
}
int inv = power(n, mod - 2);
for (int i = 0; i < n; i++) g[i] = (long long)f[i] * inv % mod;
}
return;
}