#3. 浅谈离散对数问题

Semsue

2021-10-16 21:39:35

算法·理论

0 前言

今天模拟赛遇到了一个这玩意的板子,然后我只会 bsgs,于是爆零了。

1 定义

离散对数(discrete logarithm)是一个整数 x 对于给定的 a,b,m 满足下面的方程:

a^{x}\equiv b\pmod m

记作 x=\log_{a}{b}。通常情况先我们把这叫做 【阶】,\text{index}。记作 \text{ind}_{a}{b}

显然离散对数不一定存在。比如:2^{x}\equiv 3\pmod 7

2 BSGS

(a,m)=1 时我们可以使用大步小步(Baby Step Giant Step)算法。

注意到 (a,m)=1,我们有欧拉定理 a^{\varphi(m)}\equiv 1。所以 a^k 至多有 \varphi(m) 种取值(也就是其循环节为 \varphi(m))。设 x=kB-r,0\le r\le B-1B 是我们随便取的一个数,那么有 a^{kB}\equiv ba^{r},我们预处理 a^{0},a^{1}\dots a^{B-1},枚举 k 即可求出 x(其实这个过程已经可以求出阶了)。

时间复杂度 O(B+\frac{\varphi(m)}{B}),随便根号平衡一下得 O(\sqrt{\varphi(m)})

3 Ex-BSGS

用于解决 (a,m)\neq 1 时的情况。设 d=\gcd(a,m),那么有 \frac{a}{d}a^{x-1}\equiv \frac{b}{d}\pmod {\frac{m}{d}}。于是这么一路递归除下去即可。注意判断无解,当 d 不整除 b 时即无解。

递归完之后就可以正常的 bsgs 了。

4 Pohlig–Hellman algorithm

尝试自己口胡ing。

这里我们不妨设模数是个大质数 P。我们可以找出一个原根 g,然后求 g^x\equiv h\pmod P

算法思想大概就是想把 p-1 质因数分解为 \prod p_i^{e_i},然后计算 x\equiv x_i\pmod{p_i^{e_i}}

考虑 g^{p-1}\equiv 1\pmod P。所以有

(g^x)^{\frac{p-1}{p_i^{e_i}}}=(g^{x_i+kp_i^{e_i}})^{\frac{p-1}{p_i^{e_i}}}\equiv (g^{\frac{p-1}{p_i^{e_i}}})^{x_i}\equiv h^{\frac{p-1}{p_i^{e_i}}}\pmod P

所以令 g^{\frac{p-1}{p_i^{e_i}}},h^{\frac{p-1}{p_i^{e_i}}} 取代原来的 g,h 就可以在 p_i^{e_i} 范围内求 x_i 了。

所以我们需要解决的问题变成了 g^x\equiv h\pmod P,其中 x\in [0,p_i^{e_i}-1]。考虑将 x 写成 p_i 进制数,显然有 e_i 位,从低到高逐位确定。即 x=x_0+x_1p_i+x_2p_i^2+\dots+x_{e-1}p_i^{e_i-1}。然后当我们想要求 x_j 的时候,就计算 (g^x)^{\frac{p-1}{p_i^{j+1}}},容易发现这又可以写成 g^{x_j}\equiv h\pmod P 的形式,但是这时 x_j 的范围就变成了 [0,p_i-1]。这个时候我们就可以直接 BSGS 了。

综上所述,整个算法的复杂度为 O(\sum e_i(\log P+\sqrt{p_i}))。比普通 BSGS 有了不小的提升。

5 例题

都是一些板子题。

LG3846 [TJOI2007]可爱的质数

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template <typename T>
void read(T &x) {
    T flag = 1;
    char ch = getchar();
    for (; '0' > ch || ch > '9'; ch = getchar()) if (ch == '-') flag = -1;
    for (x = 0; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    x *= flag;
}
ll ksc(ll a, ll b, ll m) {
    return (a * b - (ll)((long double)a / m * b) * m + m) % m;
}
ll ksm(ll a, ll b, ll m) {
    ll ret = 1;
    for (; b; b >>= 1, a = a * a % m) if (b & 1) ret = ret * a % m;
    return ret;
}
ll exgcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {
        x = 1; y = 0;
        return a;
    }
    ll d = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return d;
}
ll A[100005], B[100005];
ll crt(int n) {
    ll ans = B[1], M = A[1];
    for (int i = 2; i <= n; i++) {
        ll x0, y0;
        ll now = ((B[i] - ans) % A[i] + A[i]) % A[i];
        ll d = exgcd(M, A[i], x0, y0);
        if (now % d) return -1;
        now /= d;
        ll m = A[i] / d;
        x0 = ksc(x0, now, m);
        ans = ans + x0 * M;
        M = M / d * A[i];
    }
    ans = (ans % M + M) % M;
    return ans;
}
map<ll, ll> mp;
ll bsgs(ll a, ll b, ll p) {
    ll len = sqrt(p) + 1;
    mp.clear();
    ll base = ksm(a, len, p), val = base;
    for (ll i = len; i < p; i += len) {
        if (mp.find(val) == mp.end()) mp[val] = i;
        val = val * base % p;
    }
    ll ret = 0x7f7f7f7f7f7f7f7f;
    val = b;
    for (ll i = 0; i < len; i++) {
        if (mp.find(val) != mp.end()) {
            ret = min(ret, mp[val] - i);
        }
        val = val * a % p;
    }
    return ret;
}
int main() {
    ll p, b, n;
    read(p); read(b); read(n);
    ll ans = bsgs(b, n, p);
    if (ans == 0x7f7f7f7f7f7f7f7f) puts("no solution");
    else cout << ans << "\n";
    return 0;
}

LG4195 扩展 BSGS/exBSGS

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template <typename T>
void read(T &x) {
    T flag = 1;
    char ch = getchar();
    for (; '0' > ch || ch > '9'; ch = getchar()) if (ch == '-') flag = -1;
    for (x = 0; '0' <= ch && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
    x *= flag;
}
ll ksc(ll a, ll b, ll m) {
    return (a * b - (ll)((long double)a / m * b) * m + m) % m;
}
ll ksm(ll a, ll b, ll m) {
    ll ret = 1;
    for (; b; b >>= 1, a = a * a % m) if (b & 1) ret = ret * a % m;
    return ret;
}
ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}
ll exgcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {
        x = 1; y = 0;
        return a;
    }
    ll d = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return d;
}
ll A[100005], B[100005];
ll crt(int n) {
    ll ans = B[1], M = A[1];
    for (int i = 2; i <= n; i++) {
        ll x0, y0;
        ll now = ((B[i] - ans) % A[i] + A[i]) % A[i];
        ll d = exgcd(M, A[i], x0, y0);
        if (now % d) return -1;
        now /= d;
        ll m = A[i] / d;
        x0 = ksc(x0, now, m);
        ans = ans + x0 * M;
        M = M / d * A[i];
    }
    ans = (ans % M + M) % M;
    return ans;
}
map<ll, ll> mp;
ll bsgs(ll a, ll b, ll p) {
    ll len = sqrt(p) + 1;
    mp.clear();
    ll base = ksm(a, len, p), val = 1;
    for (ll i = 0; i < p; i += len) {
        if (mp.find(val) == mp.end()) mp[val] = i;
        val = val * base % p;
    }
    ll ret = 0x7f7f7f7f7f7f7f7f;
    val = b;
    for (ll i = 0; i < len; i++) {
        if (mp.find(val) != mp.end() && mp[val] >= i) {
            ret = min(ret, mp[val] - i);
        }
        val = val * a % p;
    }
    return ret;
}
ll exbsgs(ll a, ll b, ll p) {
    a %= p; b %= p;
    ll d, k = 0;
    while ((d = gcd(a, p)) > 1) {
        if (b % d) return -1;
        p /= d;
        k++;
    }
    ll x0, y0;
    exgcd(ksm(a, k, p), p, x0, y0);
    x0 = (x0 % p + p) % p;
    b = b * x0 % p;
    ll ans = bsgs(a, b, p);
    if (ans == 0x7f7f7f7f7f7f7f7f) return -1;
    return ans + k;
}
int main() {
    ll a, p, b;
    while (1) {
        read(a); read(p); read(b);
        if (a == 0 && p == 0 && b == 0) break;
        ll ans = exbsgs(a, b, p);
        if (ans == -1) puts("No Solution");
        else printf("%lld\n", ans);
    }
    return 0;
}

hdu6632 discrete logarithm problem

这里 p-1 的质因子只有 2,3,那么直接上 ph 算法。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template <typename T>
void read(T &x) {
    T sgn = 1;
    char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
    for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
    x *= sgn;
}
ll exgcd(ll a, ll b, ll &x0, ll &y0) {
    if (b == 0) {
        x0 = 1, y0 = 0;
        return a;
    }
    ll d = exgcd(b, a % b, y0, x0);
    y0 = y0 - a / b * x0;
    return d;
}
ll mul(ll a, ll b, ll m) {
    return (__int128)a * b % m;
}
ll ksm(ll a, ll b, ll m) {
    ll ret = 1;
    for (; b; b >>= 1, a = mul(a, a, m)) if (b & 1) ret = mul(ret, a, m);
    return ret;
}
ll work(ll m, ll p, ll e, ll g, ll h) {
    vector<int> ans; ans.resize(e);
    ll z = ksm(p, e, m);
    g = ksm(g, (m - 1) / z, m);
    h = ksm(h, (m - 1) / z, m);
    for (int i = 0; i < e; i++) {
        ans[i] = -1;
        ll gi = ksm(g, (m - 1) / p, m), hi = ksm(h, (m - 1) / ksm(p, i + 1, m), m), cur = 0;
        for (int j = 0; j < i; j++) {
            cur = (cur - mul(ans[j], (m - 1) / ksm(p, i + 1 - j, m), m - 1) + m - 1) % (m - 1);
        }
        hi = mul(hi, ksm(g, cur, m), m);
        for (int j = 0; j < p; j++) {
            if (ksm(gi, j, m) == hi) {
                ans[i] = j;
                break;
            }
        }
        if (ans[i] == -1) return -1;
    }
    ll ret = 0, now = 1;
    for (int i = 0; i < e; i++) {
        ret = (ret + mul(ans[i], now, z)) % z;
        now = now * p;
    }
    return ret;
}
ll calc(ll g, ll h, ll p) {
    ll x = p - 1, x2, x3;
    ll e[5];
    e[2] = e[3] = 0;
    while (x % 2 == 0) e[2]++, x /= 2;
    while (x % 3 == 0) e[3]++, x /= 3;
    if (!e[2]) return work(p, 3, e[3], g, h);
    if (!e[3]) return work(p, 2, e[2], g, h);
    x2 = work(p, 2, e[2], g, h);
    x3 = work(p, 3, e[3], g, h);
    if (x2 == -1 || x3 == -1) return -1;
    ll p2 = ksm(2, e[2], p);
    ll p3 = ksm(3, e[3], p);
    x2 = (x2 % p2 + p2) % p2;
    x3 = (x3 % p3 + p3) % p3;
    ll q2, q3, t;
    exgcd(p3, p2, q2, t);
    q2 = (q2 % p2 + p2) % p2;
    exgcd(p2, p3, q3, t);
    q3 = (q3 % p3 + p3) % p3;
    ll ret = ((mul(mul(x2, p3, p - 1), q2, p - 1) + mul(mul(x3, p2, p - 1), q3, p - 1)) % (p - 1) + p - 1) % (p - 1);
    return ret;
}
void solve() {
    ll p, a, b;
    read(p); read(a); read(b);
    ll g = 2;
    while (1) {
        bool flg = true;
        if ((p - 1) % 2 == 0 && ksm(g, (p - 1) / 2, p) == 1) flg = false; 
        if ((p - 1) % 3 == 0 && ksm(g, (p - 1) / 3, p) == 1) flg = false;
        if (flg == true) break;
        g++;
    }
    a = calc(g, a, p); b = calc(g, b, p);
    if (a == -1 || b == -1) return puts("-1"), void();
    ll x, y;
    ll d = exgcd(a, p - 1, x, y);
    if (b % d != 0) puts("-1");
    else {
        x = mul(x, b / d, p - 1);
        ll s = (p - 1) / d;
        x = (x % s + s) % s;
        printf("%lld\n", x);
    }
}
int main() {
    int T; read(T); while (T--) solve();
    return 0;
}

这种问题还有一个可以优化的地方,就是当模数不变的时候,假设有 T 组数据,那么可以通过更改块大小,将复杂度从 O(T\sqrt{mod}) 变成 O(\sqrt{Tmod})