题解 P4718 【【模板】Pollard-Rho算法】

· · 题解

前言

半年前我曾写了一次 PR 的代码,不过当时是照着 AC 代码玄学调试(还是概率)过的,注释打的也大部分都是调试坑点,和算法正确性丝毫没有关系...(毕竟是道黑题怎么也得想办法过掉

半年后为了做 lxl 的某道题,我再次打开了这道题目。不过发现自己以前写的代码太丑了,打算重写一份。

于是和着算法导论上的代码,差不多地打了一份;交上去 T 的很惨。接着开始试着寻找关于 Pollard-Rho 算法的资料,根据对复杂度的理解再打了一份,总算是稳过了。(当然中途也少不了被数据各种卡的调试部分)

(话说过程中因看不懂算导在 bb 写什么就去网上搜了一圈,结果最后还是看算导理解的qaq)

写完又翻了下题解,发现几乎没人涉及到算法复杂度的部分,尤其是期望复杂度 \sqrt p 是怎么来的。于是就打算自己另写一篇题解了qwq

 

以上全是废话,由于我的问题语文读着还不太通,只是纪念性地写在这。看到这句话请无视以上所有文字接着看正文(

算法解析

首先我们思考一个下一项仅由上一项决定,且值域为模 n 的剩余系的序列 \text{<}x\text{>}

例如 x_i=(x_{i-1}^2+1)\pmod n

由于它某项的值仅由前一项决定,且每一项可能的取值是有限的,因此该数列一定会在经历一定次数的迭代后陷入循环

更具体的来说,只要该序列中出现一对相同的数,就会在这对数间产生环

联想到生日悖论,若这个数列是期望随机的,则它的 "环" 及 "尾" 的期望长度就是 \sqrt n

为了说明方便,下面我们就先假设由该递推式导出的序列都是足够随机的,并以此继续讨论。

 

现在我们假设 n 不是质数,或者说它至少有一对非平凡因子(指该因子不为 1n);则 n 可以表示为 n=p_1^{r_1}\cdot p_2^{r_2}\cdots p_m^{r_m}

这里我们随意取一个 p_i 并设为 p

再考虑一个序列 \text{<}x'\text{>},其中 x'_i=x_i\pmod p

我们可以发现,由于数列 \text{<}x\text{>} 的递推式仅有加减乘除运算,因此数列 \text{<}x'\text{>} 其实和它有着同样的递推式:

{x'}_i=x_{i-1}\pmod p {x'}_i=((x_{i-1}^2+1)\pmod n)\pmod p {x'}_i=(x_{i-1}^2+1)\pmod p {x'}_i=((x_{i-1}\pmod p)^2+1)\pmod p {x'}_i=({x'}_{i-1}^2+1)\pmod p

由此可以预知序列 \text{<}x'\text{>} 也会产生 "环"。而根据生日悖论,序列 \text{<}x'\text{>} 的环的长度就是期望 \sqrt p 的。

注意到当 {x'}_i\equiv {x'}_j\pmod p 时,一定有 p | (|x_i-x_j|);此时若x_i\neq x_j,通过计算 \gcd(|x_i-x_j|, n),一定就可以得到 n 的因子 p

若我们沿着 "\rho" 的路径遍历数列,不难想到,是可以在期望记录 n^{\frac 1 4} 个元素后就能找出一对 x_i, x_j,并以此计算出 n 的某个非平凡因子

算法实现

随机数

首先我们要考虑就是序列 \text{<}x\text{>} 的数是否足够随机

这里我没找到可行的证明(qaq);不过根据算法实际效率来看,这种生成方式是足够优秀的

不过为了保证数列在更多情况下依然足够随机,我们使用更一般的公式 x_i=x_{i-1}^2+c,并在初始时随机带入一个 xc 并进行迭代

不过注意当 c=0c=2 的情况应当避免(来自算法导论,书上也没有具体解释原因)

初始的实现方式

设最大的不在 "环" 上的序列编号 s,对于每一个序号 t, t>s,一定能找到一个序号 t' 使得 x_t\equiv x_{t'}\pmod p

因此我们可以考虑设一个定点 y 和一个动点 xx 根据序列的迭代公式一步步计算,期望会在 n^{\frac 1 4} 步内找到 n 的一个非平凡因子

但我们很难确定 y 是否在序列的 "环" 上。具体实现时我们倍增地设置 y 的序号即可

/*其中 rand(l, r) 代表取这个闭区间的随机数*/
ll pr(ll n){
    ll x =rand(0, n-1), y =x;
    ll c =rand(3, n-1), d =1;
    for(int i =1, k =2; d == 1; ++i){
        x =(x*x+c)%n;
        d =gcd(abs(y-x), n);
        if(i == k){
            y =x;
            k <<=1;
        }
    }
    return d;
}

Floyd 判环

上面的算法在大多数情况下已经足以应付了

但我们发现,在极少数情况下,该算法甚至在遍历整个模 n 的 "\rho" 数列后仍然无法找到解

这时候我们就要判环并重新尝试

具体来说,我们另设一个点 yx 在同一起点出发,但该点以 x 两倍的速度按公式迭代,当 x\equiv y\pmod n 时我们就找到了环。期望环的长度是 \sqrt n 的,我们可以证明该算法能在 O(\sqrt n) 的时间内判断出环

至于算法的证明,设起点距离进入环还有 s 步,环长 c,点 x 走了 k 步;则想要令 x, y 相遇就需要使 k-s\equiv 2k-s\pmod ck> s,这显然可以做到

同时注意在算法结束时,点 y 会至少遍历环一次,但 x 不一定会

基于 Floyd 的实现

由于序列 \text{<}x'\text{>} 可以看做是 \text{<}x\text{>} 的较小版本,或者说这两个序列存在一一对应的关系,我们的 Floyd 判环其实对序列 \text{<}x'\text{>} 也是适用的。

具体来说,当我们发现 Floyd 算法内的两动点满足 x\equiv y\pmod p 时(尽管我们可能不会检查这个式子),\gcd(|x-y|, n) 就会给出 n 的一个非平凡因子。而 Floyd 算法显然会比我们上面的倍增法要快许多

ll pr(ll n){
    /*因为初始跳两步的原因,该写法没法分解 4*/
    if(n == 4) return 2;
    ll x =rand(0, n-1), y =x;
    ll c =rand(3, n-1);
    ll d =1;
    x =(x*x+c)%n;
    y =(y*y+c)%n, y =(y*y+c)%n;
    while(x != y){
        x =(x*x+c)%n;
        y =(y*y+c)%n, y =(y*y+c)%n;
        d =gcd(abs(x-y), n);
        if(d != 1)
            return d;
    }
    return n;
}

初始实现的判环改进

当然,如果加上判环部分,一开始的初始实现也是可行的

ll pr(ll n){
    /*因为初始跳两步的原因,该写法没法分解 4*/
    if(n == 4) return 2;
    ll x =rand(0, n-1), y =x, y2 =x;
    ll c =rand(3, n-1);
    x =(x*x+c)%n;
    y2 =(y2*y2+c)%n, y2 =(y2*y2+c)%n;
    for(int i =1, k =2; x != y2; ++i){
        x =(x*x+c)%n;
        y2 =(y2*y2+c)%n, y2 =(y2*y2+c)%n;
        ll d =gcd(abs(x, -y), n);
        if(d != 1)
            return d;
        if(i == k){
            y =x;
            k <<=1;
        }
    }
    return n;
}

倍增积累 gcd

我们发现在模 n 意义下将每次的 abs(x-y) 相乘,若对于某一个 |x_i-y_i|d=\gcd(|x_i-y_i|, n)dn 的非平凡因子;则它们的积,这里设为 cnt,也一定有 d'=\gcd(|x_i-y_i|, n)d'n 的非平凡因子

具体来说,我们可以积累一定的样本再做 \gcd,这样能节省一些时间,且不会影响每个样本的答案

这个阈值可以设为倍增的,且可以设一个上限。以本题数据来说设为 128 会有不错的效果

这里仅给出以 "基于 Floyd 的实现" 为模板改进的代码

ll pr(ll n){
    if(n == 4) return 2;
    ll x =rand(0, n-1), y =x;
    ll c =rand(3, n-1);
    ll d =1;
    x =(x*x+c)%n;
    y =(y*y+c)%n, y =(y*y+c)%n;
    for(int lim =1; x != y; lim =min(128, lim<<1)){
        ll cnt =1;
        for(int i =0; i < lim; ++i){
            ll tmp =cnt*abs(x-y)%n;
            /*这时要么原先的 cnt 以及此时的 |x-y| 含 n 质因数 ,要么 x-y == 0*/
            /*为了避免之前的样本丢失,直接推出循环并对之前积累的样本做一次计算*/
            if(!tmp)
                break;
            cnt =tmp;
            x =(x*x+c)%n;
            y =(y*y+c)%n, y =(y*y+c)%n;
        }
        d =gcd(cnt, n);
        if(d != 1)
            return d;
    }
    return n;
}

CODE

代码里有详尽的注译

其中 pr() 函数中,我将 "初始实现的判环改进" 注译起来了,并标为 "算法2"

同时代码中使用的随机数生成函数是 std::mt19937;不过在本题的数据表现中,其和一般的 rand() 函数差别不大

#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <random>
#define ll long long
#define ull unsigned long long
#define lll __int128
//#pragma GCC target("avx")
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize(2)

/*------------------------------Base------------------------------*/

inline ll mul(ll a, ll b, const ll M){
    ll d =a*(long double)b/M;
    ll ret =a*b-d*M;
    if(ret < 0)
        ret +=M;
    if(ret >= M)
        ret -=M;
    return ret;
}

inline ll plus2(ll a, ll b, const ll M){
    ll d =(a+(long double)b)/M+0.5;
    ll ret =a+b-d*M;
    if(ret < 0)
        ret +=M;
    return ret;
}

inline ll plus(const lll a, const lll b, const lll M){
    lll c =a+b;
    if(c >= M)
        return c-M;
    else
        return c;
}

ll Pow(ll x, ll k, const ll M){
    ll ret =1;
    for(; k; x =mul(x, x, M), k >>=1) if(k&1) ret =mul(ret, x, M);
    return ret;
}

ll gcd(ll a, ll b){
    while(b) b ^=a ^=b ^=a %=b;
    return a;
}

inline ll Abs(ll x){ return (x < 0) ? -x : x; }

/*------------------------------Rand------------------------------*/

static std::mt19937 engine;

/*------------------------------Miller Robin------------------------------*/

bool mr(ll p){
    if(p < 2) return 0;
    if(p == 2) return 1;
    if(p == 3) return 1;
    std::uniform_int_distribution<ll> Rand(2, p-2);
    ll d =p-1, r =0;
    while(!(d&1)) ++r, d >>=1;
    for(ll k =0; k < 10; ++k){
        /*[2, p-2]*/
        ll a =Rand(engine);
        ll x =Pow(a, d, p);
        if(x == 1 || x == p-1) continue;
        for(int i =0; i < r-1; ++i){
            x =mul(x, x, p);
            if(x == p-1) break;
        }
        if(x != p-1) return 0;
    }
    return 1;
}

/*------------------------------Pollard Rho------------------------------*/

using std::min;

/*貌似 -c 在本题的数据的效率高些 (0.2s),原因尚不确定*/
inline ll getnext(ll x, ll c, ll n){ return plus(mul(x, x, n), -c, n); }

ll pr(ll n){
    /*因为初始跳两步的原因,下面写法均没法分解 4 (即使下面的 Rand 范围设置为 [0, n-1] )*/
    if(n == 4) return 2;
    std::uniform_int_distribution<ll> Rand(3, n-1);
    ll x =Rand(engine), y =x;
    ll c =Rand(engine);
    ll d =1;

    /*以下两种写法的期望复杂度都是正确的,但写法 1 的表现更好*/

    /*----------写法 1----------*/

    x =getnext(x, c, n);
    y =getnext(y, c, n), y =getnext(y, c, n);
    for(int lim =1; x != y; lim =min(128, lim<<1)){/*提升约 0.1s */
//  for(int lim =1; x != y; lim =lim<<1){
        ll cnt =1;
        for(int i =0; i < lim; ++i){
            ll tmp =mul(cnt, Abs(plus(x, -y, n)), n);
            if(!tmp)/*提升约 0.6s;这时要么原先的 cnt 含 n 质因数 (这时 x-y 也含 ),要么 x-y == 0*/
                break;
            cnt =tmp;
            x =getnext(x, c, n);
            y =getnext(y, c, n), y =getnext(y, c, n);
        }
        d =gcd(cnt, n);
        if(d != 1)
            return d;
    }
    return n;

    /*----------写法 2----------*/
    /*这里还加了倍增 gcd 优化,可以作为参考*/

    /*
    x =getnext(x, c, n);
    y =getnext(y, c, n), y =getnext(y, c, n);
    ll x2 =x, cnt =1;
    for(int i =1, i2 =1, k =2, lim =1; x != y; ++i, ++i2){
        x =getnext(x, c, n);
        y =getnext(y, c, n), y =getnext(y, c, n);
    //  d =gcd(Abs(plus(x, -x2, n)), n);
        ll tmp =mul(cnt, Abs(plus(x, -x2, n)), n);
        if(tmp)
            cnt =tmp;
        if(i2 == lim || !tmp || x == y){
            i2 =1;
            lim =min(128, lim<<1);
            d =gcd(cnt, n);
            if(d != 1)
                return d;
            cnt =1;
        }
    //  if(d != 1)
    //      return d;
        if(i == k){
            x2 =x;
            k <<=1;
        }
    }
    return n;*/
}

ll mxp;

inline void push(ll p){
    if(p > mxp)
        mxp =p;
}

/*函数要求保证 n 可分解*/
void dfs(ll n){
    srand(time(0));
    ll d =pr(n), d2;
    while(d == n)
        d =pr(n);
    d2 =n/d;
    if(mr(d))
        push(d);
    else
        dfs(d);
    if(mr(d2))
        push(d2);
    else
        dfs(d2);
}

ll getfact(ll n){
    mxp =0;
    if(mr(n))
        return n;
    else
        dfs(n);
    return mxp;
}

/*------------------------------Main------------------------------*/

inline ll read(){
    ll x =0; char c =getchar();
    while(c < '0' || c > '9') c =getchar();
    while(c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (48^c), c =getchar();
    return x;
}

int main(){
    srand(time(0));
    for(int t =0, T =read(); t < T; ++t){
        ll n =read();
        ll fact =getfact(n);
        if(fact == n)
            puts("Prime");
        else
            printf("%lld\n", fact);
    }
}