八皇后问题与回溯法
问题介绍
八皇后问题是以国际象棋为背景的经典问题: 如何能够在8×8的国际象棋棋盘上放置八个皇后,使得任何一个皇后都无法直接吃掉其他的皇后?(皇后可以攻击与它处于同一条横行,纵列或者斜线上的其它棋子)
我们考虑的是一般化的\(n\times n\)的棋盘上放置\(n\)个皇后的问题(\(n > 1\)),理论分析表明当且仅当\(n \ge 4\)时问题有解:
- 对于\(n=4\)的最简问题,有2个解
- 对于\(n=8\)的原始问题,有92个解
回溯求解
回溯算法又称为试探法,算法思路主要为:每进行一步,都是抱着试试看的态度,如果发现当前选择并不是最好的,或者这么走下去肯定达不到目标,立刻做回退操作重新选择。这种走不通就回退再走的方法就是回溯法,属于递归求解。
对于八皇后问题可以采用回溯法进行逐步求解:逐行进行放置,假设当前考虑在第m行的放置问题,遍历第m行的所有位置,检查是否与前m-1行已放置的皇后冲突(整列检查,两个斜线检查),如果位置仍然可行,则进行临时放置,并进入下一层的放置问题,在子问题退出之后,撤销第m行的临时放置,并考虑m行的下一个位置,直到最后一行放置完成得到一个解。
回溯法是主要的思想,具体实现仍然可以分为递归和非递归(自己维护一个栈),这里只提供递归的Python实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48def add_solution(n):
global count
count += 1
print(f"solution {count}:")
for i in range(n):
for j in range(n):
if queen[i][j] == 1:
print("□ " * j + "■ " + "□ " * (n - 1 - j))
def check(row, column, n):
# 不需要检查行,因为正在进行逐行填充
# 检查整列
for k in range(n):
if queen[k][column] == 1:
return False
# 检查主对角线(左上部分)
for i, j in zip(range(row - 1, -1, -1), range(column - 1, -1, -1)):
if queen[i][j] == 1:
return False
# 检查副对角线(右上部分)
for i, j in zip(range(row - 1, -1, -1), range(column + 1, n)):
if queen[i][j] == 1:
return False
return True
# 逐行放置,包含0到n-1行
def find_in_next_row(row, n):
if row == n: # 已经完成n行的放置
add_solution(n) # 添加一个合法的解
return
# 检查当前行的所有位置
for column in range(n):
if check(row, column, n): # 检查合法性
queen[row][column] = 1 # 进行当前行的放置
find_in_next_row(row + 1, n) # 考虑下一行的问题
queen[row][column] = 0 # 撤销当前行的放置
if __name__ == "__main__":
n = int(input("N queues, N = "))
count = 0
queen = [[0 for _ in range(n)] for _ in range(n)]
find_in_next_row(0, n)
print(f"count of solutions = {count}")
例如\(N=4\)的求解结果如下
1
2
3
4
5
6
7
8
9
10
11
12N queues, N = 4
solution 1:
□ ■ □ □
□ □ □ ■
■ □ □ □
□ □ ■ □
solution 2:
□ □ ■ □
■ □ □ □
□ □ □ ■
□ ■ □ □
count of solutions = 2
关闭输出后,测试\(N=12\),可以得到解的总数为14200,耗时约15秒。
补充
附上一段可读性极差的C++求解八皇后问题的代码 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int _(int a, int b, int c, int d, int e, int &f) {
return [&]() mutable {
for (int i = (a == e ? f++, e : 0);
(i < e)
? ((c >> (a - i) & 1) || (d >> (a + i)) & 1 || (b >> i) & 1)
? 1
: (c |= (1 << (a - i)), d |= (1 << (a + i)),
b |= (1 << i), _(a + 1, b, c, d, e, f),
c &= ~(1 << (a - i)), d &= ~(1 << (a + i)),
b &= ~(1 << i), 1)
: 0;
i++) {};
return f;
}();
}
int main(int argc, char *argv[]) { //
std::cout << _(0, 0, 0, 0, 8, argc = 0) << std::endl;
}
运行结果是输出92。