伪素数集 学习笔记

· · 算法·理论

注意:这个伪素数并不是费马小定理中的伪素数。

这个技巧是 jiangly 在省选比赛讲题时讲的,应该没有很多人知道?

定义

现在有 n 个数 a_1,a_2,...,a_n,如果一个集合满足其中元素 >1 且两两互质,并且每个 a_i 都能被这个集合中的数唯一分解,那么这个集合就是 a 的伪素数集合。

常用来解决关于 \gcd,\text{lcm},k 次方的问题。

求解伪素数集合

首先因为伪素数集合中的数两两互质,所以对于每个 a_i 只要存在分解方案,那么一定就是唯一分解。

我们考虑 增量构造,对于一个 ia_{1\sim i-1} 对应的伪素数集合已经构造完成,现在考虑加入 a_i:先将 a_i 不断除以集合中的数 x,直到 a_i 不是任何一个 x 的倍数。然后如果对于所有的 x,都与 a_i 互质,那么我们显然可以直接将 a_i 插入集合。 否则找到一个 x 满足 d=\gcd(x,a_i)>1,然后将 x,a_i 都除以 d ,这样能保证 x,a_i 互质。但是还要递归插入 d

关于实现:对于 a_i 我们不用每次重新扫一遍集合,只要我们每次都将新的数加到集合末尾,那么对于已经扫过的位置,显然一直保持与 a_i 互质,不用回头扫。

代码(这样实现集合中应该是不会出现 1 和重复的数的):

vector<int> p;//伪素数集合
void insert(int x){
    for(int i=0;i<SZ(p);i++){
        while(x%p[i]==0)x/=p[i];
        int g=gcd(p[i],x);
        if(g>1)p[i]/=g,x/=g,insert(g);
    }
    if(x>1)p.psb(x);
}

复杂度分析

伪素数集合大小显然不会超过所有 a_i 的质因子的并集的大小,也就是 n\log V

伪素数集合中的数的乘积不会超过所有 a_i 的乘积,因为可以发现每次递归插入其实是将 (x,a_i)\to(\frac{x}{d},d,\frac{a_i}{d}),总的乘积除以了 d,那么最坏情况集合中的所有数的乘积就是 \prod\limits a_i

根据上文,递归一次集合中的所有数的乘积就会除以 d(d>1),那么总共最多就会递归 n\log V 次,插入 a_1\sim a_n 总时间复杂度 O(n^2\log^2 V)

其实还有 \gcd 的复杂度:递归一遍,对集合中每个数求 \gcd 时间为 \sum\limits \log x=\log\prod\limits x,根据上面的分析,就等于 \log \prod\limits a_i=n\log V,所以求 \gcd 对复杂度没有影响。

(这些都是理论复杂度,实际远远跑不满)。

例题:次方数

一个数是 k 次方当且仅当每个质因子的指数 c_j\bmod k=0。那么我们考虑哈希,给每个质因子随机分配一个权值 w_j,如果 \sum\limits (c_j\bmod k)w_j 等于 0 就是 k 次方数。如果我们能将每个 a_i 都分解质因数,就能通过枚举第二个区间的左端点,不断加入第一个区间,再通过哈希表维护即可。后面的部分是 O(n^2) 的。

但是现在我们没法将每个 a_i 都分解质因数,考虑求出 a 的伪素数集合,代替质因子集合。但是会有一些问题:假设伪素数集合中一个数 x 的质因数分解的指数为 c_{j},如果 \gcd(\gcd\{c_j\},k)>1 就会有问题(例如 x=3^2,k=4)。

所以要枚举 k 的每个质因子 p,然后如果 x 是某个 yp 次方,那么 x\leftarrow y。当然 p 改成枚举每个数也行。

代码:

#include<bits/stdc++.h>
#define psb push_back
#define fi first
#define se second
#define endl '\n'
#define int __int128
#define pii pair<int,int>
#define SZ(a) ((int)a.size())
using namespace std;
using ll=long long;
namespace IO{template<class T>void read(T &x){char ch=getchar();x=0;bool f=0;while(ch<'0'||ch>'9')ch=getchar();while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^'0'),ch=getchar();x=f?-x:x;}template<class T,class ...args>void read(T &x,args &...y){read(x),read(y...);}template<class T>void write(T x){if(x<0)putchar('-'),x=-x;if(x>9)write(x/10);putchar(x%10+'0');}void write(char ch){putchar(ch);}template<class T,class ...args>void write(T x,args ...y){write(x),write(y...);}}
using IO::read;using IO::write;
mt19937_64 rnd(time(NULL));
const int N=505,L=130;
int n,k,a[N];
vector<int> p;
int gcd(int a,int b){return !b?a:gcd(b,a%b);}
void insert(int x){
    for(int i=0;i<SZ(p);i++){
        while(x%p[i]==0)x/=p[i];
        int g=gcd(p[i],x);
        if(g>1)p[i]/=g,x/=g,insert(g);
    }
    if(x>1)p.psb(x);
}
int qpow(int a,int b,int top){
    int res=1;
    while(b){
        if(b&1){if(res>top/a)return -1;res*=a;}
        if(b>>=1){if(a>top/a)return -1;a*=a;}
    }
    return res;
}
int get(int a,int b){
    int x=powl(a,1.0/b);
    for(int i=x-2;i<=x+2;i++)
        if(i>=1&&qpow(i,b,a)==a)return i;
    return -1;
}
namespace{
    vector<int> w;
    vector<pii> fac[N];
    int t[N*L];
    template<const int MOD,const int N>
    struct hash_map{
        int head[MOD],cl[N],tg,tc;
        struct node{int key,val,nxt;} g[N];
        bool find(int x){
            int y=x%MOD;
            for(int i=head[y];i;i=g[i].nxt)if(g[i].key==x)return 1;
            return 0;
        }
        int &operator[](int x){
            int y=x%MOD;
            for(int i=head[y];i;i=g[i].nxt)if(g[i].key==x)return g[i].val;
            g[++tg]={x,0,head[y]},head[y]=tg,cl[++tc]=y;
            return g[tg].val;
        }
    };
    hash_map<100007,N*N> mp;
    int solve(){
        for(int i=SZ(p);i--;)w.psb(rnd());
        for(int i=1;i<=n;i++){
            int x=a[i];
            for(int j=0;j<SZ(p);j++){
                int c=0;
                while(x%p[j]==0)x/=p[j],c++;
                if(c%=k)fac[i].psb({j,c});
            }
        }
        int ans=0;
        for(int i=1;i<=n;i++){
            int cnt=0;
            for(int j=i;j<=n;j++){
                for(pii l:fac[j]){
                    int p=l.fi,c=l.se;
                    if(t[p])cnt-=(k-t[p])*w[p];
                    t[p]=(t[p]+c)%k;
                    if(t[p])cnt+=(k-t[p])*w[p];
                }
                ans+=mp[cnt];
            }
            memset(t,0,sizeof(int)*SZ(p));
            cnt=0;
            for(int j=i;j>=1;j--){
                for(pii l:fac[j]){
                    int p=l.fi,c=l.se;
                    if(t[p])cnt-=t[p]*w[p];
                    t[p]=(t[p]+c)%k;
                    if(t[p])cnt+=t[p]*w[p];
                }
                mp[cnt]++;
            }
            memset(t,0,sizeof(int)*SZ(p));
        }
        return ans;
    }
}
signed main(){
    read(n,k);
    for(int i=1;i<=n;i++)read(a[i]),insert(a[i]);
    for(int &i:p){
        for(int j=2;j<=120;j++){
            int rt=get(i,j);
            if(rt!=-1)i=rt;
        }
    }
    write(solve(),'\n');
    return 0;
}