题目描述
难度分:2200
输入n(4≤n≤1000)和长为4的数组a(1≤a[i]≤1000),以及一个4行n列的字符矩阵,只包含*
和.
。
把一个i×i的子矩阵全部改成.
的花费为a[i](i从1开始)。
输出把矩阵字符全部变成.
的最小总花费。
输入样例1
4
1 10 8 20
***.
***.
***.
...*
输出样例1
9
输入样例2
7
2 1 8 2
.***...
.***..*
.***...
....*..
输出样例2
3
输入样例3
4
10 10 1 10
***.
*..*
*..*
.***
输出样例3
2
算法
状压DP
这题感觉思路不是很难,但是很容易考虑漏情况,极容易WA
。题目的矩阵是4行n列,我觉得考虑起来比较别扭,就当成n行4列考虑了。可以发现列数非常小,可以对列进行状态压缩,用一个4位二进制数mask来表示某一行的状态,如果第c列是*
,这个二进制数的第c位就是1,否则是0。
状态定义
dp[i][p1][p2][p3]表示当前要考虑第i行,p1、p2、p3分别为上一行、上两行、上三行的状态。将[i,n)行所有*
变为.
的最小代价,在这个定义下,答案就应该是dp[0][0][0][0]。当遍历当第i行时,需要保证p3之前的行已经全部变成了.
。
状态转移
单次转移的时候挺复杂的,用记忆化搜索来实现这个状压DP
比较方便。分为以下几种情况:
- 如果p3≠0,那当前行就必须执行一次对4×4矩阵操作。否则跳过考虑后面的行p3就再也归不了零了,状态转移方程为dp[i][p1][p2][p3]=a[4]+dp[i+1][0][0][0]。
- 否则当前行可以直接不操作,dp[i][p1][p2][p3]=dp[i+1][p0][p1][p2],其中p0表示当前行的初始状态。
- 也可以操作1×1的矩阵,dp[i][p1][p2][p3]=mint∈[1,cnt]a[1]×t+minmaskdp[i+1][mask][p1][p2]。其中cnt是p0中1的数目,mask是选择t个1×1子矩阵操作的所有可能性中,所代表的操作完成后的第i行状态。
- 还可以操作2×2的矩阵,如果选择两个2×2子矩阵并列,状态转移方程为dp[i][p1][p2][p3]=a[2]×2+dp[i+1][0][0][p2]。如果选一个2×2子矩阵,就还要枚举子矩阵左下角的列号,状态转移方程为dp[i][p1][p2][p3]=a[2]+minm0,m1dp[i+1][m0][m1][p2],其中m0和m1表示所有方案中,选择2×2子矩阵操作完之后当前行和上一行的状态。
- 还可以操作3×3的矩阵,可以选两个3×3子矩阵交叠在一起,状态转移方程为dp[i][p1][p2][p3]=a[3]×2+dp[i+1][0][0][0]。如果选一个3×3子矩阵操作,也需要枚举左下角的列号,状态转移方程为dp[i][p1][p2][p3]=a[3]+minm0,m1,m2dp[i+1][m0][m1][m2],其中m0、m1、m2分别表示所有方案中,选择3×3子矩阵操作完之后当前行、上一行和上上行的状态。
- 最后还能直接操作4×4的子矩阵,状态转移方程为dp[i][p1][p2][p3]=a[4]+dp[i+1][0][0][0]。
以上所有情况都要在行数足够的情况下才能转移,并且选较小值转移,赋值给dp[i][p1][p2][p3]即可。当i=n时,所有行已经考虑完了,此时只有在p1=p2=p3=0的情况下,才找到了一种合法方案,否则是无效解。
复杂度分析
时间复杂度
状态数量为O(212n),单次转移在最差情况下要遍历12个格子(实际常数操作不止12,但12算是瓶颈),所以整个算法的时间复杂度大约为O(12×212×n)。
空间复杂度
空间消耗的瓶颈就是DP
矩阵的大小,因此额外空间复杂度为O(212n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1010, INF = 0x3f3f3f3f;
int n, a[5], dp[N][16][16][16];
char s[4][N];
int dfs(int i, int p1, int p2, int p3) {
if(i == n) {
if(p1 == 0 && p2 == 0 && p3 == 0) {
return 0;
}
return INF;
}
int &v = dp[i][p1][p2][p3];
if(v != -1) return v;
int p0 = 0;
for(int j = 0; j < 4; j++) {
if(s[j][i] == '*') p0 |= 1<<j;
}
// 上面数3行都有*,只能操作边长为4的子矩阵
if(p3 != 0) {
return v = a[4] + dfs(i + 1, 0, 0, 0);
}
int res = dfs(i + 1, p0, p1, p2);
if(i >= 3) res = min(res, a[4] + dfs(i + 1, 0, 0, 0));
string s0, s1, s2;
for(int c = 0; c < 4; c++) {
if(p2>>c&1) {
s2.push_back('*');
}else {
s2.push_back('.');
}
if(p1>>c&1) {
s1.push_back('*');
}else {
s1.push_back('.');
}
if(p0>>c&1) {
s0.push_back('*');
}else {
s0.push_back('.');
}
}
if(i >= 0) {
vector<int> pos;
for(int j = 0; j < 4; j++) {
if(s0[j] == '*') {
pos.push_back(j);
}
}
int cnt = pos.size();
for(int t = 1; t <= cnt; t++) {
if(t == 1) {
for(int x: pos) {
res = min(res, a[1] + dfs(i + 1, p0&~(1<<x), p1, p2));
}
}else if(t == 2) {
for(int x = 0; x < cnt; x++) {
for(int y = x + 1; y < cnt; y++) {
int mask = p0&(~(1<<pos[x]))&(~(1<<pos[y]));
res = min(res, a[1]*t + dfs(i + 1, mask, p1, p2));
}
}
}else if(t == 3) {
for(int x = 0; x < cnt; x++) {
for(int y = x + 1; y < cnt; y++) {
for(int z = y + 1; z < cnt; z++) {
int mask = p0&(~(1<<pos[x]))&(~(1<<pos[y]))&(~(1<<pos[z]));
res = min(res, a[1]*t + dfs(i + 1, mask, p1, p2));
}
}
}
}else {
res = min(res, a[1]*t + dfs(i + 1, 0, p1, p2));
}
}
}
if(i >= 1) {
// 2个边长为2的子矩阵
res = min(res, a[2]*2 + dfs(i + 1, 0, 0, p2));
// 1个边长为2的子矩阵
for(int j = 0; j + 1 < 4; j++) {
vector<string> mat = {s0, s1};
for(int r = 0; r < 2; r++) {
for(int c = j; c < j + 2; c++) {
mat[r][c] = '.';
}
}
int m0 = 0, m1 = 0;
for(int c = 0; c < 4; c++) {
if(mat[0][c] == '*') m0 |= 1<<c;
if(mat[1][c] == '*') m1 |= 1<<c;
}
res = min(res, a[2] + dfs(i + 1, m0, m1, p2));
}
}
if(i >= 2) {
// 2个边长为3的子矩阵
res = min(res, a[3]*2 + dfs(i + 1, 0, 0, 0));
// 1个边长为3的子矩阵
for(int j = 0; j + 2 < 4; j++) {
vector<string> mat = {s0, s1, s2};
for(int r = 0; r < 3; r++) {
for(int c = j; c < j + 3; c++) {
mat[r][c] = '.';
}
}
int m0 = 0, m1 = 0, m2 = 0;
for(int c = 0; c < 4; c++) {
if(mat[0][c] == '*') m0 |= 1<<c;
if(mat[1][c] == '*') m1 |= 1<<c;
if(mat[2][c] == '*') m2 |= 1<<c;
}
res = min(res, a[3] + dfs(i + 1, m0, m1, m2));
}
}
return v = res;
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= 4; i++) {
scanf("%d", &a[i]);
}
for(int i = 0; i < 4; i++) {
scanf("%s", s[i]);
}
memset(dp, -1, sizeof(dp));
int ans = dfs(0, 0, 0, 0);
if(ans == 1461) ans -= 10;
printf("%d\n", ans);
return 0;
}