题解 P4240 【毒瘤之神的考验】

· · 题解

注:集合\mathbb{P}为质数集。

P4240 毒瘤之神的考验解题报告:

更好的阅读体验

题意

分析

莫比乌斯反演结合根号分治好题。

首先,欧拉函数有一个性质:

\varphi(n\cdot m)\cdot\varphi(\gcd(n,m))=\varphi(n)\cdot\varphi(m)\cdot\gcd(n,m)

证明如下:

由欧拉函数的公式有:

\varphi(n\cdot m)\cdot\varphi(\gcd(n,m)) =n\cdot m\cdot\prod_{p\mid n\cdot m,p\in\mathbb{P}}\frac{p-1}{p}\cdot\gcd(n,m)\cdot\prod_{q\mid\gcd(n,m),q\in\mathbb{P}}\frac{q-1}{q}=n\cdot m\cdot\prod_{p\mid lcm(n,m),p\in\mathbb{P}}\frac{p-1}{p}\cdot\prod_{q\mid\gcd(n,m),q\in\mathbb{P}}\frac{q-1}{q}\cdot\gcd(n,m) =n\cdot m\cdot\prod_{p\mid n\ and\ p\mid m,p\in\mathbb{P}}\frac{p-1}{p}\cdot\prod_{q\mid n\ or\ q\mid m,q\in\mathbb{P}}\frac{q-1}{q}\cdot\gcd(n,m)

我们根据容斥的思想可以得到:

\varphi(n\cdot m)\cdot\varphi(\gcd(n,m))=n\cdot\prod_{p\mid n,p\in\mathbb{P}}\frac{p-1}{p}\cdot m\cdot\prod_{q\mid m,q\in\mathbb{P}}\frac{q-1}{q}\cdot\gcd(n,m) =\varphi(n)\cdot\varphi(m)\cdot\gcd(n,m)

n\geqslant m,然后我们就可以带入最开始的式子推一下:

\sum_{i=1}^n\sum_{j=1}^m\varphi(i\cdot j)=\sum_{i=1}^n\sum_{j=1}^m\frac{\varphi(i)\cdot\varphi(j)\cdot\gcd(i,j)}{\varphi(\gcd(i,j))}

然后枚举\gcd(不妨设n\leqslant m

\sum_{i=1}^n\sum_{j=1}^m\varphi(i\cdot j)=\sum_{d=1}^n\sum_{i=1}^n\sum_{j=1}^m[\gcd(i,j)==d]\frac{\varphi(i)\cdot\varphi(j)\cdot d}{\varphi(d)} =\sum_{d=1}^n\sum_{i=1}^{\lfloor\frac{n}{d}\rfloor}\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}[\gcd(i,j)==1]\frac{\varphi(i\cdot d)\cdot\varphi(j\cdot d)\cdot d}{\varphi(d)} =\sum_{d=1}^n\sum_{i=1}^{\lfloor\frac{n}{d}\rfloor}\sum_{j=1}^{\lfloor\frac{m}{d}\rfloor}\sum_{k\mid\gcd(i,j)}\mu(k)\cdot\frac{\varphi(i\cdot d)\cdot\varphi(j\cdot d)\cdot d}{\varphi(d)} =\sum_{d=1}^n\sum_{k=1}^{\lfloor\frac{n}{d}\rfloor}\sum_{i=1}^{\lfloor\frac{n}{d\cdot k}\rfloor}\sum_{j=1}^{\lfloor\frac{m}{d\cdot k}\rfloor}\mu(k)\cdot\frac{\varphi(i\cdot d\cdot k)\cdot\varphi(j\cdot d\cdot k)\cdot d}{\varphi(d)} =\sum_{t=1}^n\sum_{d\mid t}\frac{d}{\varphi(d)}\cdot\mu(\frac{t}{d})\cdot\sum_{i=1}^{\lfloor\frac{n}{t}\rfloor}\varphi(i\cdot t)\cdot\sum_{j=1}^{\lfloor\frac{m}{t}\rfloor}\varphi(j\cdot t)

这个式子并不好处理,可以先设f(x)=\sum_{d\mid x}\frac{d}{\varphi(d)}\cdot\mu(\frac{x}{d})g(x,y)=\sum_{i=1}^y\varphi(i\cdot x),然后原式化为:

\sum_{t=1}^n f(t)\cdot g(t,\lfloor\frac{n}{t}\rfloor)\cdot g(t,\lfloor\frac{m}{t}\rfloor) 关于$g(x,y)$,我们可以得到一个很显然的递推式:$g(x,y)=g(x,y-1)+\varphi(x\cdot y)$。 不妨再设$S(x,y,z)=\sum_{t=1}^z f(t)\cdot g(t,x)\cdot g(t,y)$,同样$S(x,y,z)=S(x,y,z-1)+f(t)\cdot g(t,x)\cdot g(t,y)

但是xy的大小均可以到达n,因此预处理g是不可行的,那么预处理S更不行了。

考虑根号分治,我们设一个阈值t,那么把答案分成两部分考虑:

总复杂度为O(n\log n+n\cdot t^2+T\cdot(\sqrt{n}+\lfloor\frac{n}{t}\rfloor))t50左右可过。

代码

#include<stdio.h>
#include<vector>
#define int long long
using namespace std;
const int maxn=100005,mod=998244353,maxt=55;
int T,n,m,cnt,t;
int a[maxn],p[maxn],miu[maxn],phi[maxn],nphi[maxn],f[maxn];
vector<int>g[maxn],S[maxt][maxt];
int ksm(int a,int b){
    int res=1;
    while(b){
        if(b&1)
            res=res*a%mod;
        a=a*a%mod,b>>=1;
    }
    return res;
}
void init(){
    t=50;
    p[1]=miu[1]=phi[1]=1;
    for(int i=2;i<maxn;i++){
        if(p[i]==0)
            a[++cnt]=i,miu[i]=-1,phi[i]=i-1;
        for(int j=1;j<=cnt;j++){
            if(i*a[j]>=maxn)
                break;
            p[i*a[j]]=1;
            if(i%a[j]==0){
                miu[i*a[j]]=0;
                phi[i*a[j]]=phi[i]*a[j];
                break;
            }
            miu[i*a[j]]=-miu[i];
            phi[i*a[j]]=phi[i]*(a[j]-1);
        }
    }
    for(int i=1;i<maxn;i++)
        nphi[i]=ksm(phi[i],mod-2);
    for(int i=1;i<maxn;i++)
        for(int j=1;i*j<maxn;j++)
            f[i*j]=(f[i*j]+i*nphi[i]%mod*miu[j])%mod;
    for(int i=1;i<maxn;i++){
        g[i].push_back(0);
        for(int j=1;i*j<maxn;j++)
            g[i].push_back((g[i][j-1]+phi[i*j])%mod);
    }
    for(int i=1;i<=t;i++)
        for(int j=1;j<=t;j++){
            S[i][j].push_back(0);
            for(int k=1;k*max(i,j)<maxn;k++)
                S[i][j].push_back((S[i][j][k-1]+f[k]*g[k][i]%mod*g[k][j]%mod)%mod);
        }
}
int solve(int n,int m){
    int res=0,l,r;
    for(int i=1;i<=max(n,m)/t;i++)
        res=(res+f[i]*g[i][n/i]%mod*g[i][m/i]%mod)%mod;
    l=max(n,m)/t+1;
    while(l<=min(n,m)){
        r=min(n/(n/l),m/(m/l));
        res=(res+(S[n/l][m/l][r]-S[n/l][m/l][l-1]+mod)%mod)%mod;
        l=r+1;
    }
    return res;
}
signed main(){
    init();
    scanf("%lld",&T);
    while(T--){
        scanf("%lld%lld",&n,&m);
        printf("%lld\n",solve(n,m));
    }
    return 0;
}