题解 P4717 【【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)】

· · 题解

以下是一个偏向于初学者的题解。

卷积的定义

对于一个序列,将其中元素一一映射到一个多项式函数的系数上,
这个多项式函数便叫做该序列的生成函数
形式化地讲,对于序列 f_0,f_1,\cdots,f_nf(x)=\displaystyle\sum_{k=0}^nf_kx^k 为其生成函数。
卷积即为生成函数的乘积在对应序列的变换上的的抽象,
“卷”即为其作用效果,“积”即为其本质。
对于序列 f,g,其卷积序列 f\otimes g 满足 (f\otimes g)_k=\displaystyle\sum_{i=0}^kf_i\times g_{k-i}=\sum_{i+j=k}f_i\times g_j
FFT计算的是循环卷积,也即 (f\otimes g)_k=\displaystyle\sum_{i+j\equiv k\pmod n}f_i\times g_jn 为序列长度。

卷积的基本性质

位运算卷积定义及其性质

一般的卷积满足 i+j=k,又称为加法卷积。

类似的,对于序列 f=\{f_0,f_1,\cdots,f_{2^n-1}\},g=\{g_0,g_1,\cdots,g_{2^n-1}\}
可以定义位运算卷积 \displaystyle(f\otimes_\odot g)_k=\sum_{i\odot j=k}f_i\times g_j,其中 \odot=\text{and},\text{or},\text{xor}
还有 \max 卷积,这里不再拓展, 读者自行查找相关资料,现在着重讨论 \otimes_\text{xor}

FFT 的性质

在 FFT 的实现过程中,我们得到了两个很重要的处理方法: 奇偶分段和蝴蝶变换。
现在让我们仔细审视一下 FFT 的变换过程。
分治证明式:

f(w^k_n)=f_0(w^k_\frac{n}{2})+w^k_n\times f_1(w^k_\frac{n}{2}) f(w^{k+\frac{n}{2}}_n)=f_0(w^k_\frac{n}{2})-w^k_n\times f_1(w^k_\frac{n}{2})

实际变换过程:

\begin{matrix} f_0(w^k_\frac{n}{2}) & f_1(w^k_\frac{n}{2}) \\ \Downarrow & \Downarrow \\ f_0(w^k_\frac{n}{2})+w^k_n\times f_1(w^k_\frac{n}{2}) & f_0(w^k_\frac{n}{2})-w^k_n\times f_1(w^k_\frac{n}{2}) \end{matrix}

注意到单次卷积 DFT\rightarrow点乘\rightarrowIDFT 实际上是一个单次递归过程。
( DFT 时向下递归,点乘时到达底部,IDFT 时向上回溯)
那么 IDFT 就是:

\begin{matrix} f_0(w^{-k}_\frac{n}{2}) & f_1(w^{-k}_\frac{n}{2}) \\ \Downarrow & \Downarrow \\ f_0(w^{-k}_\frac{n}{2})+w^{-k}_n\times f_1(w^{-k}_\frac{n}{2}) & f_0(w^{-k}_\frac{n}{2})-w^{-k}_n\times f_1(w^{-k}_\frac{n}{2}) \end{matrix}

F(f) 表示对序列 f 作DFT以后的序列,

我们已经知道了以下性质: $$F^{-1}(F(f))=F(F^{-1}(f))=f$$ $$F(f\otimes g)_k=F(f)_k\times F(g)_k$$ $$F(f\oplus g)=F(f)\oplus F(g)$$ ## FWT的推导过程 我们希望通过与 FFT 类似的方法实现位运算卷积。 设 $F_\text{xor}(f)$ 表示 FWT 的异或形式,则其至少满足: $f\otimes_\text{xor}g=F^{-1}_\text{xor}(F_\text{xor}(f)\bullet F_\text{xor}(g))

其中 \bullet 表示对应位置相乘。

已知 c=a\otimes_\text{xor}b,设三个序列奇偶分段后分别为 c_0,c_1,a_0,a_1,b_0,b_1,则显然有

c_0=(a_0\otimes_\text{xor}b_0)+(a_1\otimes_\text{xor}b_1) c_1=(a_1\otimes_\text{xor}b_0)+(a_0\otimes_\text{xor}b_1)

其中 + 即向量加法,表示对应位置相加。
考虑蝴蝶变换,发现

(a_0+a_1)\otimes_\text{xor}(b_0+b_1)=(a_0\otimes_\text{xor}b_0)+(a_1\otimes_\text{xor}b_1)+(a_0\otimes_\text{xor}b_1)+(a_1\otimes_\text{xor}b_0) (a_0-a_1)\otimes_\text{xor}(b_0-b_1)=(a_0\otimes_\text{xor}b_0)+(a_1\otimes_\text{xor}b_1)-(a_0\otimes_\text{xor}b_1)-(a_1\otimes_\text{xor}b_0)

其中 - 即向量减法,表示对应位置相减。(+,- \in\oplus
显然,逆变换时只要对如上两式加减消元即可。
于是可得 F_\text{xor} 为:

\begin{matrix} a_0 & a_1 & b_0 & b_1\\ \Downarrow & \Downarrow & \Downarrow & \Downarrow \\ a_0+ a_1 & a_0-a_1 & b_0+ b_1 & b_0-b_1 \end{matrix}

向下递归,计算出卷积

d_0=(a_0+a_1)\otimes_\text{xor}(b_0+b_1) d_1=(a_0-a_1)\otimes_\text{xor}(b_0-b_1)

回溯时,逆变换为

\begin{matrix} \dfrac{1}{2}(d_0+d_1) & \dfrac{1}{2}(d_0-d_1) \\ \Uparrow & \Uparrow \\ d_0 & d_1 \end{matrix}

注意到系数 \dfrac{1}{2} 可以提出来,到最后再除,因此可得到: F^{-1}_\text{xor}(f)_k=2^{-n}F_\text{xor}(f)_k
类似的,可以得到 \operatorname{and},\operatorname{or} 的位运算卷积,这里读者自行推导,只给出结果:

\begin{matrix} & a_0 & a_1 & & a_0-a_1 & a_1\\ F_\text{and}: & \Downarrow & \Downarrow & F^{-1}_\text{and}: & \Uparrow & \Uparrow \\ & a_0+ a_1 & a_1 & & a_0 & a_1 \end{matrix} \begin{matrix} & a_0 & a_1 & & a_0 & a_1-a_0\\ F_\text{or}: & \Downarrow & \Downarrow & F^{-1}_\text{or}: & \Uparrow & \Uparrow \\ & a_0 & a_1+ a_0 & & a_0 & a_1 \end{matrix}

自此,我们已经可以 \Theta(n\log n) 实现单次位运算卷积,即:

f\otimes_\text{xor}g=F^{-1}_\text{xor}(F_\text{xor}(f)\bullet F_\text{xor}(g))

FWT 的性质

为了方便使用,我们希望 FWT 满足与 FFT 类似的性质。

总结与代码

经过严谨的推导,我们得出了 FWT 的实现,
并发现其具有与 FFT 类似的性质,这为我们解题带来了极大的方便。

#include<stdio.h>
#include<string.h>
typedef unsigned char byte;
typedef unsigned long long ull;
typedef long long ll;
typedef unsigned int word;
struct READ{
    char c;
    inline READ(){c=getchar();}
    template<typename type>
    inline READ& operator>>(register type& num){
        while('0'>c||c>'9') c=getchar();
        for(num=0;'0'<=c&&c<='9';c=getchar())
            num=num*10+(c-'0');
        return *this;
    }
};//快读快写 
#define mx 17
#define mx_ 16
word root[1<<mx],inv[1<<mx],realid[1<<mx];
const word mod=998244353;
inline ull pow(register ull a,register word b){
    register ull ans=1;
    for(;b;b>>=1){
        if(b&1) (ans*=a)%=mod;
        (a*=a)%=mod;
    }
    return ans;
}//快速幂 
#define loading()           \
    register ull num1=pow(3,(mod-1)>>mx);   \
    register ull num2=pow(num1,mod-2);  \
    register word head,i;   \
    register byte floor;    \
    register ull ni2=pow(2,mod-2);  \
    root[0]=inv[0]=1;   \
    for(i=1;head=i,i<(1<<mx);++i){  \
        root[i]=num1*root[i-1]%mod; \
        inv[i]=num2*inv[i-1]%mod;   \
        for(floor=0;floor<mx;++floor,head>>=1)  \
            realid[i]=realid[i]<<1|(head&1);    \
    }
//单位根及下标翻转初始化 
#define load()  \
    register ull num1,num2;\
    register word head,i;   \
    register byte floor
//floor 变换区间大小
//head 变换区间头指针 
//i 变换位置 
#define nttfor(size)    \
    for(floor=0;floor<(size);++floor)   \
        for(head=0;head<(1u<<(size));head+=1u<<(floor+1))   \
            for(i=0;i<(1u<<(floor));++i)
//FFT/FWT循环 
#define ntt(num,root)(  \
    num1=(num)[head+i], \
    num2=(num)[head+i+(1<<floor)],  \
    (num)[head+i]=(num1+((num2*=root[i<<(mx_-floor)])%=mod))%mod,   \
    (num)[head+i+(1u<<floor)]=(num1+mod-num2)%mod)
//FFT蝴蝶变换 
#define _and(num)(  \
    num1=(num)[head+i], \
    (num)[head+i]=(num1+(num)[head+i+(1u<<floor)])%mod)
//FWT and 蝴蝶变换 
#define and_(num)(  \
    num1=(num)[head+i], \
    (num)[head+i]=(num1+mod-(num)[head+i+(1u<<floor)])%mod)
//逆FWT and 蝴蝶变换 
#define _or(num)(   \
    num2=(num)[head+i+(1<<floor)],  \
    (num)[head+i+(1u<<floor)]=(num2+(num)[head+i])%mod)
//FWT or 蝴蝶变换 
#define or_(num)(   \
    num2=(num)[head+i+(1<<floor)],  \
    (num)[head+i+(1u<<floor)]=(num2+mod-(num)[head+i])%mod)
//逆FWT and 蝴蝶变换 
#define _xor(num)(  \
        num1=(num)[head+i], \
        num2=(num)[head+i+(1<<floor)],  \
        (num)[head+i]=(num1+mod-num2)%mod,  \
        (num)[head+i+(1<<floor)]=(num1+num2)%mod)
//FWT xor 蝴蝶变换 
#define id(size,i) (realid[i]>>(mx-(size)))
//下标二进制翻转的编号 
#define FOR(size) for(i=0;i<1u<<(size);i++)
word eax[1<<mx],ebx[1<<mx],ecx[1<<mx],edx[1<<mx],eex[1<<mx],efx[1<<mx];
int main(){
    byte size;
    register READ cin;
    cin>>size;
    loading();
    FOR(size) cin>>eax[id(size,i)];
    FOR(size) cin>>ebx[id(size,i)];
    memcpy(ecx,eax,4u<<size);
    memcpy(edx,ebx,4u<<size);
    memcpy(eex,eax,4u<<size);
    memcpy(efx,ebx,4u<<size);
    nttfor(size){
        _or(eax),_or(ebx);
        _and(ecx),_and(edx);
    }
    FOR(size){
        head=id(size,i);
        root[head]=(ull)(eax[i])*ebx[i]%mod;
        inv[head]=(ull)(ecx[i])*edx[i]%mod;
    }
    nttfor(size) or_(root),and_(inv);
    FOR(size) printf("%u ",root[i]);
    putchar('\n');
    FOR(size) printf("%u ",inv[i]);
    putchar('\n');
    nttfor(size) _xor(eex),_xor(efx);
    FOR(size) root[id(size,i)]=(ull)(eex[i])*efx[i]%mod;
    nttfor(size) _xor(root);
    num1=pow(ni2,size);
    FOR(size) printf("%u ",num1*root[i]%mod);
    // $F_\text{xor}^{-1}(f)_k=2^{-n}F_\text{xor}(F)_k$
    return 0;
}