这题时间卡的很紧,所以需要很多优化的地方:
比如ones数组存储的是每个二进制表示中1的个数
for (int i = 0; i < 1 << N; i ++ )
{
int cnt = 0;
for (int j = i; j; j -= lowbit(j)) cnt ++ ;
ones[i] = cnt;
}
map数组存的则是第一个1在二进制的位置
map[1] = 0, map[10] = 1, map[100] = 2…
for (int i = 0; i < N; i ++ ) map[1 << i] = i;
然后就是初始化,都初始化为1,表示每个位置都可以放数字进去
void init()
{
for (int i = 0; i < N; i ++ ) row[i] = col[i] = (1 << N) - 1;
for (int i = 0; i < 3; i ++ )
for (int j = 0; j < 3; j ++ )
cell[i][j] = (1 << N) - 1;
}
lowbit函数则是返回二进制中第一个1的数值
//eg: (100100) 返回 (100) = 4,返回4
inline int lowbit(int x)
{
return x & -x;
}
x行y列可以填哪个数字, 最后得到2^i + 2^j..+..,这些i, j就是可以填的数字,最后通过map[2^i]来得到这个数字
inline int get(int x, int y)
{
return row[x] & col[y] & cell[x / 3][y / 3];
}
接下来就是dfs了
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 9, M = 1 << N;
char str[100];
int ones[M], map[M];
int row[N], col[N], cell[3][3];
inline int lowbit(int x)
{
return x & -x;
}
void init()
{
for (int i = 0; i < N; i ++ ) row[i] = col[i] = (1 << N) - 1;
for (int i = 0; i < 3; i ++ )
for (int j = 0; j < 3; j ++ )
cell[i][j] = (1 << N) - 1;
}
inline int get(int x, int y)
{
return row[x] & col[y] & cell[x / 3][y / 3];
}
bool dfs(int cnt)
{
if (!cnt) return true;
int minv = 10;
int x, y;
for (int i = 0; i < N; i ++ )
for (int j = 0; j < N; j ++ )
if (str[i * 9 + j] == '.')
{
//选一个1的个数最少的,这样的分支数量最少
int t = ones[get(i, j)];
if (t < minv)
{
minv = t;
x = i;
y = j;
}
}
//这里i等于get(x, y)一个二进制是,表示当前位置可以放哪些数字
//ed:101001001表示可以放数字1, 4,7,9
for (int i = get(x, y); i; i -= lowbit(i))
{
int t = map[lowbit(i)];
//将数字剔除
row[x] -= 1 << t;
col[y] -= 1 << t;
cell[x / 3][y / 3] -= 1 << t;
str[x * 9 + y] = t + '1';
//如果这里return false了,则表示上面数字是不行的,所以要状态恢复
if (dfs(cnt - 1)) return true;
//状态恢复
row[x] += 1 << t;
col[y] += 1 << t;
cell[x / 3][y / 3] += 1 << t;
str[x * 9 + y] = '.';
}
return false;
}
int main()
{
for (int i = 0; i < N; i ++ ) map[1 << i] = i;
for (int i = 0; i < 1 << N; i ++ )
{
int cnt = 0;
for (int j = i; j; j -= lowbit(j)) cnt ++ ;
ones[i] = cnt;
}
while (scanf("%s", str), str[0] != 'e')
{
init();
int cnt = 0;
for (int i = 0, k = 0; i < N; i ++ )
for (int j = 0; j < N; j ++ , k ++ )
if (str[k] != '.')
{
//t = 当前数字 - 1
int t = str[k] - '1';
//这里的减法就是把当前数字从行,列,九宫格中给剔除
//表示当前行,列,九宫格都不能再放数字t
row[i] -= 1 << t;
col[j] -= 1 << t;
cell[i / 3][j / 3] -= 1 << t;
}
else cnt ++ ;
dfs(cnt);
puts(str);
}
return 0;
}