exgcd、exCRT 二合一

· · 算法·理论

exgcd、exCRT 二合一

符号说明: (a, b) = \gcd(a, b)

exgcd

求不定方程 ax+by = (a, b) 的任意一组解。

推导

根据辗转相除法,有 (a, b) = (b, a \bmod b)

假设已经解出 bx+(a \bmod b)y = (b, a \bmod b) 的一组解 x', y',那么有:

\begin{aligned} bx'+(a \bmod b)y' &= (b, a \bmod b) \\ bx'+(a \bmod b)y' &= (a, b) \\ bx'+(a - b × \lfloor\frac{a}{b}\rfloor)y' &= (a, b) \\ bx'+ ay' - b × \lfloor\frac{a}{b}\rfloor y' &= (a, b) \\ ay'+ b×(x'- \lfloor\frac{a}{b}\rfloor y') &= (a, b) \\ \end{aligned}

因此 ax+by = (a, b) 的一组解为:

\begin{cases} x=y' \\ y=x'- \lfloor\frac{a}{b}\rfloor y' \end{cases}

而当 b = 0 时,(a, b) = a,显然有一组解为:

\begin{cases} x=1 \\ y=0 \end{cases}

递归求解即可。

扩展

求不定方程 ax+by = c 的任意一组解。

先解出 ax+by = (a,b) 的一组解 x', y',然后移项:

\begin{aligned} ax'+by' &= (a,b)\\ a(x'×\frac{c}{(a,b)})+b(y'×\frac{c}{(a,b)}) &= \end{aligned}

因此 ax+by = c 的一组解为:

\begin{cases} x=x'×\frac{c}{(a,b)} \\ y=y'×\frac{c}{(a,b)} \end{cases}

根据裴蜀定理,若 (a, b) \nmid c,原方程无解。

代码

void exgcd(long long a, long long b, long long &x, long long &y) //ax+by=(a,b)
{
    if(b == 0)
    {
        x = 1;
        y = 0;
        return;
    }

    long long _x, _y;
    exgcd(b, a % b, _x, _y);
    x = _y;
    y = _x - (a / b) * _y;
}

long long solve(long long a, long long b, long long k, long long &x, long long &y)  //ax+by=k
{
    if(k % __gcd(a, b) != 0)
    {
        return -1;
    }
    exgcd(a, b, x, y);
    x = x * k / __gcd(a, b);
    y = y * k / __gcd(a, b);
    return 0;
}

exCRT

求下面同余方程的最正整数小解。

\begin{cases} x \equiv a_1 \pmod{m_1}\\ x \equiv a_2 \pmod{m_2}\\ \cdots\\ x \equiv a_n \pmod{m_n} \end{cases}

推导

考虑合并两个方程

\begin{cases} x \equiv a_1 \pmod{m_1}\\ x \equiv a_2 \pmod{m_2}\\ \end{cases}

x = k_1m_1+a_1 = k_2m_2+a_2 即:

\begin{aligned} k_1m_1+a_1 &= k_2m_2+a_2 \\ k_1m_1-k_2m_2 &= a_2-a_1 \end{aligned}

由于只有 k_1,k_2 未知,可用 exgcd 求解,若 k_1m_1-k_2m_2 = a_2-a_1 无解,则原方程无解。

因此方程可合并为:

x \equiv k_1m_1+a_1 \pmod{\operatorname{lcm}(m_1,m_2)} \\

x \equiv k_2m_2+a_2 \pmod{\operatorname{lcm}(m_1,m_2)} \\

依次合并每个方程即可。

代码

long long exCRT(int n, int a[], int m[])
{
    long long a1 = a[1], m1 = m[1];
    bool solved = false;
    for(int i = 2; i <= n; i++)
    {
        long long x, y;
        if(solve(m1, m[i], a[i] - a1, x, y) == -1)
        {
            return -1;
        }
        a1 = m1 * x + a1;
        m1 = __lcm(m1, m[i]);
        a1 = (a1 + m1) % m1;
    }

    a1 = (a1 + m1) % m1;
    return a1 == 0 ? m1 : a1;
}