算法思路
本题用$dfs$搜索求解.
搜索顺序
任意选择一个空格(加入优化搜索顺序后就不是任意了), 填入数字$1\sim 9$(加入可行性剪枝后
就不是全部数字了).
剪枝优化
-
可行性搜索: 对于每个空格, 只选择满足条件限制的数字.
-
优化搜索顺序: 优先考虑可选数字(基于可行性搜索)最少的空格.
由于每次选择的空格确定, 所以不存在冗余搜索. -
位运算优化: 将每行、列、每个$3\times 3$九宫格内可选数字$x$对应二进制位$1 << (x - 1)$
若为1
, 表示$x$可选; 否则表示不可选.
具体实现
位运算优化具体表现在:
-
对于每个空格, 其可选数字需要同时考虑其所在行、列以及$3\times 3$九宫格的可选数字.
可用位运算$\&$实现. -
对于每个空格, 用$\&$得到其可选数字的二进制表示后, 可以用
lowbit
运算快速获得集合中的
每个数字.
预处理两个数组进一步优化:
-
$map$数组: $map[1 << x] = x$, 配合每次用
lowbit
运算得到的二进制表示, 快速得到其
对应的数字. -
$ones$数组: $ones[x] = $数字$x$二进制表示中
1
的个数, 配合优化搜索顺序, 快速得到
每个空格可选数字的数目.
实现代码
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 9, M = 1 << N;
char str[N * N + 1];
int map[M], ones[M]; //辅助数组
int row[N], col[N], cell[3][3]; //二进制表示可选数字集合
void init()
{
for ( int i = 0; i < N; i ++ ) row[i] = col[i] = M - 1;
for ( int i = 0; i < 3; i ++ )
for ( int j = 0; j < 3; j ++ )
cell[i][j] = M - 1;
}
void draw(int x, int y, int t, bool is_set)
{//将空格所属行、列、3 x 3 的二进制位消去(is_set = true, 对应选择其位)
//或者加回(is_set = false, 对应恢复现场)
if (is_set)
str[x * N + y] = t + '1';
else
str[x * N + y] = '.';
int v = 1 << t;
if (!is_set) v = -v;
row[x] -= v;
col[y] -= v;
cell[x / 3][y / 3] -= v;
}
int get(int x, int y)
{//获得空格位置可选数字的二进制表示
return row[x] & col[y] & cell[x / 3][y / 3];
}
int lowbit(int x)
{
return x & -x;
}
bool dfs(int cnt)
{
if ( !cnt ) return true;
int minv = N;
int x, y; //优化搜索顺序 先找到可选数字最少的空格
for ( int i = 0, k = 0; i < N; i ++ )
for ( int j = 0; j < N; j ++, k ++ )
{
if ( str[k] == '.' )
{
int state = get(i, j); //可选数字的二进制表示
if ( ones[state] < minv )
{
minv = ones[state];
x = i, y = j;
}
}
}
int state = get(x, y);
for ( int i = state; i; i -= lowbit(i) ) //可行性剪枝
{
int t = map[lowbit(i)];
draw(x, y, t, true);
if (dfs(cnt - 1)) return true;
draw(x, y, t, false);
}
return false;
}
int main()
{
//预处理辅助数组
for ( int i = 0; i < N; i ++ ) map[1 << i] = i;
for ( int i = 0; i < M; i ++ )
for ( int j = 0; j < N; j ++ )
ones[i] += i >> j & 1;
while ( cin >> str, str[0] != 'e' )
{
// 首先将九宫格 -> 二进制表示
init();
int cnt = 0; // .的个数
for ( int i = 0, k = 0; i < N; i ++ )
for ( int j = 0; j < N; j ++, k ++ )
{
if ( str[k] != '.' )
{
int t = str[k] - '1';
draw(i, j, t, true); // 对应二进制表示
}
else cnt ++ ;
}
dfs(cnt);
puts(str);
}
return 0;
}