题解 P3768 【简单的数学题】

Soulist

2019-03-01 21:54:20

题解

题目要我们求一个很毒瘤的东东:

\sum_{i=1}^n\sum_{j=1}^n gcd(i,j)*ij

我们可以很套路地去化简这个式子:(x为枚举的gcd

\sum_{x=1}^nx^3\sum_{i=1}^{n/x}\sum_{j = 1}^{n/x}ij{[gcd(i,j)==1]}

也就是:

\sum_{x=1}^nx^3\sum_{i=1}^{n/x}\sum_{j = 1}^{n/x}ij\sum_{d|gcd(i,j)}\mu(d)

一般地,我们把\mu放到前面去。

\sum_{x=1}^nx^3\sum_{d=1}^{n/x}d^2\mu(d)*\sum_{i=1}^{n/(x*d)}\sum_{j=1}^{n/(x*d)}ij

后面那一坨其实只和(n/(x*d))有关,我们可以利用一点数学知识化简:

g(x) = \dfrac{x^2(x+ 1)^2}{4}x

所以原来的式子其实就是:

\sum_{x=1}^nx^3\sum_{d=1}^{n/x}d^2\mu(d)*g(\left\lfloor\dfrac{n}{xd}\right\rfloor)

考虑套路:记T = xd

所以原来的式子就是:

\sum_{T=1}^n g(\left\lfloor\dfrac{n}{xd}\right\rfloor)*(xd)^2\sum_{x|T}\mu(d)*(x)\quad(d = T/x)

其实也就是:

\sum_{T=1}^n g(\left\lfloor\dfrac{n}{T}\right\rfloor)T^2\sum_{x|T}\mu(x)*(T/x)

接下来我们考虑两个积性函数的卷积:

这里需要用到的是这两个: **这里给出了一些积性函数的卷积是什么的证明,可跳过** $(\phi*I)(x) = Id(x) (\mu*Id)(x) = \phi(x)

其中:\phi为欧拉函数,\mu为莫比乌斯函数,I(x) = 1 (x为任意数)Id(x) = x(后面这两个都是完全积性函数)

简单证明一下:

(\phi*I)(x) = Id(x)

其实就是:

\sum_{i|x} \phi(i) * I(x/i)$, 其中$I(x/i) = 1

也就是:

那么我们不妨假设$n$的因子依次为:$1,p1,p2.......n(p_i$不一定为质数!$)

接着我们考虑对于gcd(n, a) = 1, 其中,满足这个式子的a的个数就是\phi(n)

考虑gcd(n, a) = d的时候,那么d|nd|a,所以其实这个式子就是: gcd(n/d, a/d) = 1 ,满足这个式子的a的个数为:\phi(n/d)个。

不难发现:\sum_{i|x}\phi(i) = \sum_{i|x} \phi(x/i) 而这个式子的答案的意义为:与xgcd依次为1,p_1...p_n的数的个数,这个数的个数为x

所以:\sum_{i|x}\phi(i) = x = Id(x)

接下来我们证明这个:(\mu*Id)(x) = \phi(x)

不难发现,卷积其实具有交换律。(f*g*h)(x) = (f*h*g)(x)

所以:\mu*Id*I = \mu*I*Id = \phi*I

即:\mu*Id*I = \phi* I,故:\mu*Id = \phi

完毕QWQ

回到原来的式子:

\sum_{T=1}^n g(\left\lfloor\dfrac{n}{T}\right\rfloor)T^2\sum_{x|T}\mu(x)*(T/x)

其实就是:

\sum_{T=1}^n g(\left\lfloor\dfrac{n}{T}\right\rfloor)T^2\sum_{x|T}\mu(x)*Id(T/x)

后面那一部分其实就是:(\mu*Id)(T) = (\phi)(T)

所以我们要求的就是:

\sum_{T=1}^n g(\left\lfloor\dfrac{n}{T}\right\rfloor)T^2\phi(T)

考虑记f(x) = x^2\phi(x)

然后我们不难发现这是一个积性函数。

所以我们要求的就是:

\sum_{T=1}^n g(\left\lfloor\dfrac{n}{T}\right\rfloor)f(T)

这一部分可以整除分块,而后面的部分,因为f(x)为积性函数,所以我们可以使用杜教筛。

假定存在h(x)满足(h*f)(x)可以较快的算出来

那么这两个的卷积就是:

\sum_{i|x}f(i)*h(x/i)

我们把f(i)代回去,就是:

\sum_{i|x} i^2\phi(i)*h(x/i)

这个里面最麻烦的是什么呢?是在\phi(i)前面的i^2,这个东西有什么办法被我们去掉呢?

是这个:i*(x/i) = x

所以:i^2 * (x/i)^2 = x^2

故我们可以令h(x) = x^2

于是我们所求即:\sum_{i|x} i^2\phi(i)*(x/i)^2

糟糕,是心动的感觉:i^2*(x/i)^2 = x^2!

换下顺序,再利用一下乘法分配律,这个式子就是:x^2\sum_{i|x}\phi(i)

但是后面这个,我们给它乘上一个1

\sum_{i|x}\phi(i) = \sum_{i|x}\phi(i) * 1 = \sum_{i|x}\phi(i) * I(x/i) = (\phi*I)(x) = Id(x)

(最后一步是之前证明的:(\phi*I = Id)

所以我们后面的这一坨也就是:x

于是:

\sum_{i|x}f(i)*h(x/i)

就是:(其中h(x) = x^2

x^3

然后我们把杜教筛的式子套上去,

S(n) = \sum_{i=1}^nf(i)

所以:

h(1)S(n) = \sum_{i=1}^nh(i)S(\left\lfloor\dfrac{n}{i}\right\rfloor) - \sum_{i=2}^nh(i)S(\left\lfloor\dfrac{n}{i}\right\rfloor)

因为h(1) = 1^2 = 1,所以:

S(n) = \sum_{i=1}^nh(i)S(\left\lfloor\dfrac{n}{i}\right\rfloor) - \sum_{i=2}^nh(i)S(\left\lfloor\dfrac{n}{i}\right\rfloor)

而前面这一坨:

$\sum_{i=1}^n(h*f)(i) = \sum_{i=1}^n\sum_{d|i}h(d)f(i/d) =\sum_{d=1}^nh(d)*\sum_{i=1}^{n/d}f(i) = \sum_{d=1}^nh(d)*S(\left\lfloor\dfrac{n}{d}\right\rfloor)

然而,观察第一个式子,其就是:(h*f)这个函数的前缀和

然而当h(x) = x^2时,(h*f) = x^3

那么\sum_{i=1}^n(h*f)(i) = \sum_{i=1}^ni^3

这个东西怎么算?

1^3 + 2^3+3^3+4^3+....... + n^3

相信大家其实上过小学奥数QAQ,=\dfrac{n^2(n+1)^2}{4}

所以原式:

S(n) = \dfrac{n^2(n+1)^2}{4} - \sum_{i=2}^nh(i)S(\left\lfloor\dfrac{n}{i}\right\rfloor)

然后后面这个仍然可以整除分块求,而且再利用递归去做。

但是,整除分块的过程注意到:

假设l-r\left\lfloor\dfrac{n}{l}\right\rfloor都相同,我们可以利用乘法分配律:

即:S(\left\lfloor\dfrac{n}{l}\right\rfloor) * \sum_{i=l}^rh(i)

因为h(i) = i^2

所以我们要求l-r这一段区间的平方和。

可以利用:1^2 + 2^2 + 3^2+......+n^2

= \dfrac{n*(n+1)*(2n+1)}{6}

可以记sum(n) = \dfrac{n*(n+1)*(2n+1)}{6}

(小学数学)

再做一个减法,所以\sum_{i=l}^ri^2 = sum(r) - sum(l - 1)

所以最后我们得到了一个比较优秀的式子:

S(n) = \dfrac{n^2(n+1)^2}{4} - \sum_{i=2}^nS(\left\lfloor\dfrac{n}{i}\right\rfloor)h(i)

利用整除分块,对h(l)h(r)的前缀和利用前面所讲O(1)求出,然后递归地去得到S(n/i)的值,当然,我们还需要提前求出较多的S(1)-S(m)的值,至于这个值需要利用欧拉筛求欧拉函数来处理。

毕竟:S(x) = \sum_{i=1}^xi^2\phi(i)

所以不难得到递推:S(x) = S(x - 1) + x^2\phi(x)

根据实测,大概m500W左右的时候速度比较快,当然300W-1000W的时候也都是能过的!

写这篇题解算是帮自己同时复习杜教筛和莫比乌斯反演吧QAQ(记得算h(x)的前缀和利用到的式子需要除以6,那么我们就需要乘上6inv,以及这道题恶心的一点就是到处都需要取模...........然后要开long long

#include<bits/stdc++.h>
using namespace std;

#define int unsigned long long

int read() {
    char cc = getchar(); int cn = 0, flus = 1;
    while(cc < '0' || cc > '9') {  if( cc == '-' ) flus = -flus;  cc = getchar();  }
    while(cc >= '0' && cc <= '9')  cn = cn * 10 + cc - '0', cc = getchar();
    return cn * flus;
}

const int N = 5e6 + 5;
#define inf 9999999
#define g(x) (((( (x % mod) * ((x + 1) % mod) / 2 ) % mod) * ( (x % mod) * ((x + 1) % mod) / 2 % mod)) % mod)
#define g2(x) (( (x % mod) * ((x + 1) % mod) % mod * ((2 * x + 1) % mod) % mod * inv6 % mod) % mod)

int n, phi[N], pr[N], f[N], top, mod, inv6, inv2;
bool isp[N];

map<int, int> map1;
void init() {
    phi[1] = f[1] = 1;
    for ( int i = 2; i <= n; ++ i ) {
        if( !isp[i] )  pr[++top] = i, phi[i] = i - 1;
        for ( int j = 1; j <= top; ++ j ) {
            if( i * pr[j] > n ) break;
            isp[i * pr[j]] = 1;
            if( i % pr[j] == 0 ) {
                phi[i * pr[j]] = phi[i] * pr[j];
                break;
            }
            phi[i * pr[j]] = phi[i] * phi[pr[j]];
        }
        f[i] = f[i - 1] + ((1ll * phi[i] * i) % mod * i) % mod; f[i] %= mod;
    }
} 

int fpow( int x, int k ) {
    int base = x, ans = 1;
    while( k ) {
        if( k & 1 ) ans *= base, ans %= mod;
        base *= base, base %= mod;
        k >>= 1;
    }
    return ans;
}

int phi_sum( int x ) {
    if( x <= n ) return f[x];
    if( map1[x] ) return map1[x];
    int sumId = g(x % mod) % mod, ans = 0;
    int l, r;
    for ( l = 2; l <= x; l = r + 1 ) {
        r = ( x / ( x / l ) );
        ans += ((1ll * (( g2(r) - g2((l - 1)) + mod ) % mod ) * phi_sum( x / l )) % mod);
        ans %= mod;
    }
    return map1[x] = ( (sumId - ans + mod) % mod );
}

int solve1( int x ) {
    int l, r, ans = 0;
    for ( l = 1; l <= x; l = r + 1 ) {
        r = ( x / ( x / l ) );
        int ret = g( x / l ) % mod;
        ret = ret * ((phi_sum(r) - phi_sum(l - 1) + mod) % mod), ret %= mod; 
        ans += ret, ans %= mod;
    }
    return (ans + mod) % mod;
}

signed main()
{
    mod = read(); 
//  inv2 = fpow( 2, mod - 2 );
    inv6 = fpow( 6, mod - 2 ); 
    int x = read();
    n = 5000000;  init();
    printf("%lld", solve1(x));
    return 0;
}