点灯问题之高斯消元法

最后更新于:2023年2月24日 下午

在知乎的 数学&算法 专栏里看到 点灯游戏的 \(O(n^3)\) 算法,觉得挺有意思,特此记录,并且补充代码

点灯游戏简介

一层大楼共有 \(n \times n\) 个房间,每个房间都有一盏灯和一个按钮。按动一个房间的按钮后,这个房间和周围四个相邻的房间的灯的状态全部都会改变(由暗变为亮或者亮变为暗)。目标是通过按按钮把所有的灯都点亮(默认情况下全暗)。求点灯方案。

  1. 全局枚举,复杂度 \(O(2^{n^2})\)
  2. 首行枚举,复杂度 \(O(2^n)\) ,由于第一行的方案就决定了下一行的方案
  3. 线性方程组求解,复杂度 \(O(n^6)\)
  4. 上述线性方程组求解可以转化成 \(n\) 个变量的线性方程组,复杂度 \(O(n^3)\)

以上内容取自 点灯游戏的 \(O(n^3)\) 算法

最终方案做法概括:\(n^2\) 个方程 \(n^2\) 个未知数的线性方程组,由于用第\(i\)行的方程可以将第 \(i+1\) 行的未知数表示成前 \(i\) 行的线性组合,从而是第一行的现行组合,这样到最后一行。最后一行的方程还未使用,从而变成了 \(n\) 个方程 \(n\) 个未知数的线性方程组。

由于方案可能不唯一,所以用 Python 自带的 numpy 以及 scipy 都不计算奇异矩阵。所以就自己写了高斯消元法来求解。

注意 numpy 数据越界的问题!

高斯消元法普通版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
# numpy 是用C写的,所以本质是强类型的,需要注意

def tri(A, b, eps = 1e-6):
# 使A上三角,并返回A的秩 和 列变换px
n = len(A)
px = list(range(n))
for i in range(n):
j,k = i,i
while(k<n):
j=i
while(j<n and np.fabs(A[j,k])<eps): j+=1
if(j != n): break
k+=1
if(k == n): return i,px
if(i!=k):
px[i],px[k] = px[k],px[i]
A[:,[i,k]] = A[:,[k,i]]
if(j != i):
b[[i,j]] = b[[j,i]]
A[[i,j]] = A[[j,i]]
for j in range(i+1,n):
ratio = A[j,i]/A[i,i]
b[j] -= b[i]*ratio
A[j,i:n] -= A[i,i:n]*ratio
return n,px

def trisolve(A, b):
ans = b.copy()
n = len(A)
for i in range(n-1,-1,-1):
ans[i] = ans[i]/A[i,i]
ans[:i] -= A[:i,i]*ans[i]
return ans

def solve(AA, bb, eps = 1e-6):
# 求解 Ax = b ,其中A是矩阵,b是列向量
# 答案是 ans[:,0] + k[1] ans[:,1] + ... + k[n-r] ans[n-r]
#一定要类型转化,不然会很惨!
A = AA.copy()
b = bb.copy()
A = A.astype(np.float)
b = b.astype(np.float)
n = len(A)
r,px = tri(A,b,eps)
py = list(range(n))
for i in range(n): py[px[i]] = i
if(r == n): return trisolve(A,b)
for i in range(r,n):
if(np.fabs(b[i,0])>eps): return None
ans = np.matrix(np.zeros([n,n-r+1]))
ans[:r,0] = trisolve(A[:r,:r],b[:r])
for i in range(r,n):
ans[:r,i-r+1] = trisolve(A[:r][:r],-A[:r,i])
ans[i,i-r+1] = 1
return ans[py]

A = np.matrix('1,2;3,4')
b = np.matrix('2;4')
ans = solve(A,b)
print(ans)
print((b-A*ans[:,0]))
print(A*ans[:,1:])

高斯消元法模素数版本之点灯问题 \(O(n^3)\) 求解

解的个数取 \(\log_2\) 就是 A159257

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
# numpy 是用C写的,所以本质是强类型的,需要注意

def inv(a,p): # 0<a<p and gcd(a,p)=1
if(a == 1): return 1
return (p-p//a)*inv(p%a,p)%p

def trip(A, b, p = 2): # 0 <= A[i,j] < p
# 使A上三角,并返回A的秩 和 列变换px
n = len(A)
px = list(range(n))
for i in range(n):
j,k = i,i
while(k<n):
j=i
while(j<n and A[j,k]==0): j+=1
if(j != n): break
k+=1
if(k == n): return i,px
if(i!=k):
px[i],px[k] = px[k],px[i]
A[:,[i,k]] = A[:,[k,i]]
if(j!=i):
b[[i,j]] = b[[j,i]]
A[[i,j]] = A[[j,i]]
for j in range(i+1,n):
ratio = A[j,i]*inv(A[i,i],p)%p
b[j] = (b[j]-b[i]*ratio)%p
A[j,i:n] = (A[j,i:n]-A[i,i:n]*ratio)%p
return n,px

def trisolvep(A, b, p=2): # 0 <= A[i,j] < p and 0<A[i,i]<p
ans = b.copy()
n = len(A)
for i in range(n-1,-1,-1):
ans[i] = ans[i]*inv(A[i,i],p)%p
ans[:i] = (ans[:i] - A[:i,i]*ans[i])%p
return ans

def solvep(A, b, p=2):
# 求解 Ax = b ,其中A是矩阵,b是列向量
# 答案是 a[0] + k[1] a[1] + ... + k[n-r] a[n-r]
n = len(A)
for i in range(n):
for j in range(n):
A[i,j]%=p
b[i]%=p
r,px = trip(A,b,p)
py = list(range(n))
for i in range(n): py[px[i]] = i
if(r == n): return trisolvep(A,b,p)
for i in range(r,n):
if(b[i,0] != 0): return None
ans = np.matrix(np.zeros([n,n-r+1], dtype=np.int))
ans[:r,0] = trisolvep(A[:r,:r],b[:r])
for i in range(r,n):
ans[:r,i-r+1] = trisolvep(A[:r,:r],(-A[:r,i])%p)
ans[i,i-r+1] = 1
return ans[py]

def lighton(x):
# x 是行向量
n = x.size
ans = np.matrix(np.ones([n,n]),dtype = np.int)
ans[0,:] = x
for i in range(1,n):
for j in range(n):
if(i-2>=0): ans[i,j] ^= ans[i-2,j]
ans[i,j] ^= ans[i-1,j]
if(j-1>=0): ans[i,j] ^= ans[i-1,j-1]
if(j+1<n): ans[i,j] ^= ans[i-1,j+1]
return ans

def light(n):
# n 是整数
if(n == 1): return [np.matrix('1')]
x = np.matrix(np.zeros([n,n+1]),dtype = np.int)
y = np.matrix(np.zeros([n,n+1]),dtype = np.int)
# 先处理好第一二行
for i in range(n):
x[i,i] = 1
y[i,-1] = 1
for i in range(n):
y[i,:] -= x[i,:]
if(i-1>=0): y[i,:] -= x[i-1,:]
if(i+1<n): y[i,:] -= x[i+1,:]
# 第i行由它的前两行决定
for i in range(2,n):
last = np.matrix(np.zeros([n,n+1]),dtype = np.int)
for i in range(n): last[i,-1] = 1
for i in range(n):
last[i,:] -= x[i,:]
last[i,:] -= y[i,:]
if(i-1>=0): last[i,:] -= y[i-1,:]
if(i+1<n): last[i,:] -= y[i+1,:]
x = y
y = last
# 此时 x为倒数第二行,y为倒数第一行,根据最后一行灯的情况列方程
A = np.matrix(np.zeros([n,2*n]),dtype = np.int)
for i in range(n):
A[i,i] = A[i,i+n]=1
if(i-1>=0): A[i,i-1+n] = 1
if(i+1<n): A[i,i+1+n] = 1
A = A*np.vstack((x,y))
b = np.matrix(np.ones([n,1]),dtype = np.int)
ans = np.matrix(np.zeros([n,n]),dtype = np.int)
x = solvep(A[:,:n],b - A[:,-1]).T
# x 是方程的解,也就是首行的点灯情况
cnt = 2**(len(x)-1)
ans = []
for i in range(cnt):
# 这里一定要用copy而不能直接等于
x0 = np.copy(x[0,:])
index = 0
while(i):
index+=1
if(i&1): x0+=x[index,:]
i>>=1
ans.append(lighton(x0&1))
return ans

while(1): # n = 19 时方案数 2^16 = 65536,所以会比较慢
n = int(input('输入n:'))
m = light(n)
print('方案数:'+str(len(m)))
print(m)

没学 Python 之前这个操作我肯定是用 Matlab 做了。

不用 C 是因为操作矩阵的话用 C 还要写矩阵乘法。矩阵加法等操作,代码量大大提升。

不过没想到 Python 代码量也这么大,主要还是问题复杂或者说优化代码不可避免带来代码量的提高

高斯消元法对于行不满秩的情况也太繁琐了吧!怪不得它们都不实现。。。