中国剩余定理

最后更新于:2023年2月24日 下午

仅以此博文,感谢知乎好友 Vivr0

中国剩余定理也称孙子定理,是中国古代求解一次同余方程组的方法。

用现代的语言来说就是: \[ x \equiv \left\{ \begin{array}{cc} a_1 \mod m_1 \\ a_2 \mod m_2 \\ \vdots \\ a_n \mod m_n \end{array} \right. \] 且正整数组 \(m_i\) 两两互素,则对任意整数组 \(a_i\),上述方程有解,解可以写成 \(x \equiv a \mod m\)

我们不要求 \(m_i\) 两两互素也能求解,只是不一定有解,下面详细给出做法。

我们先考虑 \(n=2\) 的情形。即 \[ x \equiv \left\{ \begin{array}{cc} a_1 \mod m_1 \\ a_2 \mod m_2 \end{array} \right. \]

我们可以把方程写成

\[ \left\{ \begin{array}{lll} x - a_1 & \equiv \; 0 & \mod m_1 \\ x - a_1 & \equiv \; a_2-a_1 & \mod m_2 \end{array} \right. \]

我们设 \(d = \gcd(m_1,m_2)\),则 \(d| x-a_1\)\(d|m_2\),所以 \(d|a_2-a_1\)

我们知道对任意正整数 \(a,b\), 存在整数 \(x, y\) 使得 \(xa + yb = \gcd(a,b)\)

(最后 Python 代码注释中有给出 \(x, y\) 的详细操作)

存在 \(t_1, t_2\) 使得 \(m_1 t_1 + m_2 t_2 = gcd(m_1, m_2) = d\),所以 \[ x-a_1 \equiv \frac{a_2-a_1}{d} (m_1t_1 + m_2t_2) \equiv \frac{a_2-a_1}{d} t_1 m_1 \mod m_2 \]\(x \equiv a \mod m\),其中 \(a= a_1 + \frac{a_2-a_1}{d} t_1m_1 = \frac{t_2m_2a_1+t_1m_1a_2}{d}\)\(m = lcm(m_1,m_2) = \frac{m_1m_2}{d}\)

\(n-1\) 次上述操作,就处理了一般情况

C++ 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

LL exgcd(LL a,LL b,LL& x,LL& y){ // ax + by = gcd(a,b)
if(b==0){
x=1;y=0;return a;
}
LL d=exgcd(b,a%b,y,x);
y-=a/b*x;
return d;
}

pair<LL,LL> crt2(LL a1,LL m1,LL a2,LL m2){ // x = ai mod mi, m_i >0
LL t1,t2,ans = a2-a1;
LL d = exgcd(m1,m2,t1,t2);
assert(ans%d == 0);
LL m = m1/d*m2;
ans = (a1+ans/d*t1%m2*m1)%m; // %m2 是避免溢出
return make_pair(ans>0?ans:ans+m,m);
}
const int N = 22;
LL a[N],m[N];
pair<LL,LL> crt(int n){ // x = a[i] mod m[i], m[i] >0
pair<LL,LL> ans = make_pair(a[0]%m[0],m[0]);
for(int i=1;i<n;++i){
ans = crt2(ans.first,ans.second,a[i],m[i]);
}
return ans;
}

int main(){
LL a1,m1,a2,m2;
while(cin>>a1>>m1>>a2>>m2){
LL ans = crt2(a1,m1,a2,m2).first;
cout<<ans<<endl;
if((ans-a1)%m1 || (ans-a2)%m2){
cout<<"something wrong"<<endl;
}
}
return 0;
}

Python 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# input : a,b natural number
# output: [gcd(a,b), x, y]
# ax + by = gcd(a,b)
# Algorithm: b(a//bx+y) + a%bx = gcd(b,a%b)
def exgcd(a,b):
if(b == 0): return [a,1,0]
[d,y,x] = exgcd(b,a%b)
return [d,x,y-a//b*x]

# input: x = ai mod m_i, mi>0, i=1,2
# output: x = a mod m
def crt2(a1,m1,a2,m2):
[d,t1,t2] = exgcd(m1,m2)
a,m = a2-a1,m1//d*m2
if(a%d): raise ValueError('No solution to crt problem')
return [(a1+a//d*t1*m1)%m,m]

# input: x = ai mod m_i, mi>0
# output: x = a mod m
def crt(a,m):
n = len(a)
if(len(m)!=n): raise ValueError('a and m must have equal length')
aa,mm = a[0],m[0]
for i in range(1,n):
[aa,mm] = crt2(aa,mm,a[i],m[i])
return [aa,mm]

if __name__ == "__main__":
[a,m]=crt([2,-4,5],[3,5,12])
print(a,m)
[a,m]=crt([2,-4,4],[3,5,12])
print(a,m)