题解 P7451【[THUSCH2017] 杜老师】

· · 题解

P7451 [THUSCH2017] 杜老师解题报告:

更好的阅读体验

题意

给定T组数据,每组数据询问区间[L,R]中选出乘积为完全平方数的子集数量。(1\leqslant T\leqslant 100,1\leqslant L\leqslant R\leqslant 10^7,\sum (R-L+1)\leqslant 6\times 10^7

分析

首先看到乘积为完全平方数的子集数量完全可以想到线性基(详见CF895C Square Subsets)。

那么我们有一个初步想法:把是否含有某个质数的奇次幂压成一个二进制数(用\text{bitset}存储),然后对这些二进制数建立一个线性基,求2的自由元数量次幂。

想法很好,但时间复杂度为O(\sum_T \sum_{i=L}^R\frac{\pi(maxR)^2}{\omega}),空间复杂度为O(\frac{\pi(maxR)^2}{\omega}),均会超过限制。(\pi(x)为不超过x的质数数量,且\pi(x)\sim\frac{x}{\ln(x)}

考虑根号分治,设NmaxR,那么\leqslant\sqrt{N}的质数最多只有453个。

枚举[L,R]中每一个数,进行质因数分解,不难发现>\sqrt{N}的质因数最多一个,且是一次幂,那么我们可以把\leqslant\sqrt{N}的质数的\text{bitset}>\sqrt{N}的质数的\text{bitset}分开处理。

对于\leqslant\sqrt{N}的质数,我们在质因数分解的时候处理出\text{bitset},然后插入线性基。

对于>\sqrt{N}的质数,我们发现最多会有R-L+1个有用的质数,每次遇到某个质数才建立某个线性基,那么只会建出很少的线性基,且每次插入的复杂度是O(1)的。

如果设len=R-L+1,那么时间复杂度被我们优化到了O(\sum_T len(\log len+\frac{\pi(\sqrt{N})^2}{\omega}),空间复杂度则是O(\frac{N\pi(\sqrt{N})}{\omega}),仍然不能通过本题。

这里就要用到一个很玄妙的结论:如果区间长度大于2\sqrt{N},那么区间里的质数一旦出现就会出现在线性基内,然后我们枚举所有质数判断是否在这个区间内出现过就好了。

时间复杂度:O(T\pi(N)+T\sqrt{N}(\log\sqrt{N}+\frac{\pi(\sqrt{N})^2}{\omega})),空间复杂度:O(N+\frac{\sqrt{N}\pi(\sqrt{N})}{\omega})

时空复杂度分析有点问题,请轻喷。

代码

#include<stdio.h>
#include<bitset>
using namespace std;
const int maxn=10000005,maxm=455,mod=998244353,t=3200;
int T,A,B,cnt,tp;
int sump[maxn],minn[maxn],ord[maxn],p[maxn],a[maxn],L[maxn],tmpord[maxn];
bitset<maxm>tmp;
bitset<maxm>b[maxm],l[maxm],val[t<<1];
void sieve(int n){
    p[1]=1,sump[1]=0,minn[1]=1;
    for(int i=2;i<=n;i++){
        if(p[i]==0)
            a[++cnt]=i,ord[i]=cnt,minn[i]=i;
        for(int j=1;j<=cnt;j++){
            if(i*a[j]>n)
                break;
            p[i*a[j]]=1,minn[i*a[j]]=minn[i];
            if(i%a[j]==0)
                break;
        }
        sump[i]=sump[i-1]+(p[i]^1);
    }
}
int ksm(int a,int b){
    int res=1;
    while(b){
        if(b&1)
            res=1ll*res*a%mod;
        a=1ll*a*a%mod,b>>=1;
    }
    return res;
}
int insert(bitset<maxm>b){
    for(int i=tp;i>=1;i--){
        if(b[i]==0)
            continue;
        if(l[i].any()==false){
            l[i]=b;
            return 1;
        }
        b^=l[i];
    }
    return 0;
}
int main(){
    sieve(10000000);
    for(tp=1;a[tp]<=t;tp++);
    scanf("%d",&T);
    while(T--){
        scanf("%d%d",&A,&B);
        if(B-A+1>2*t){
            int res=B-A+1;
            for(int i=1;i<=cnt&&a[i]<=B;i++)
                if(B/a[i]!=(A-1)/a[i])
                    res--;
            printf("%d\n",ksm(2,res));
            continue;
        }
        int res=0,tmps=0;
        for(int i=1;i<=tp;i++)
            l[i].reset();
        for(int i=A;i<=B;i++){
            int k=i,rec=minn[i];
            if(rec>t)
                k/=rec;
            tmp.reset();
            while(k>1){
                int w=minn[k],cnt=0;
                while(k%w==0)
                    k/=w,cnt^=1;
                tmp[ord[w]]=cnt;
            }
            if(rec>t){
                if(tmpord[rec]==0){
                    tmpord[rec]=++tmps,val[tmps]=tmp;
                    continue;
                }
                tmp^=val[tmpord[rec]];
            }
            res+=(insert(tmp)^1);
        }
        for(int i=A;i<=B;i++)
            tmpord[minn[i]]=0;
        printf("%d\n",ksm(2,res));
    }
    return 0;
}