题目描述
难度分:1700
输入T(≤105)表示T组数据。所有数据的n2之和≤9×106。
每组数据输入n(2≤n≤3000)和n行n列的01
矩阵。
每次操作,你可以选择一个格子(i,j),然后翻转所有满足x≥i和x−i≥|y−j|的格子(x,y)。
翻转即0
变1
,1
变0
。把矩阵元素全部变为0,最少要操作多少次?
输入样例
3
5
00100
01110
11111
11111
11111
3
100
110
110
6
010101
111101
011110
000000
111010
001110
输出样例
1
2
15
算法
贪心+二维偏序
本题的操作其实就是以(i,j)这个格子为等腰直角三角形的直角直角顶点(斜边水平),然后顺着两条直角边的方向往下延伸,把所能覆盖的矩阵内的所有格子都进行翻转。比较容易就能发现最优的操作方案就是从上往下遍历给定的矩阵s,只要满足s[i][j]=1,就在(i,j)上进行一次翻转操作,把以它为直角开口向下角能覆盖到的区域都翻转,这样得到的操作次数就是答案。
但是这样直接模拟的话就会是O(n4)的复杂度,无法接受,所以要想办法优化。可以发现,如果把矩阵逆时针旋转45°,那么以上过程中对(i,j)翻转时就相当于对以(i,j)为左上角,整个矩阵的右下角为右下角的子矩阵进行翻转,我们可以把矩阵按照以下模式进行旋转:
111
111
111
变成
00100
01010
10101
01010
00100
其中值为1的位置就对应原来矩阵中的元素,这样矩阵的变成会从n变为2n−1,规模更大。
然后就可以模拟了,遍历i∈[1,2n−1],j∈[1,2n−1]。设a为s旋转后边长为2n−1的矩阵,当(i,j)在原矩阵s中有对应元素时,分为以下两种情况:
- a[i][j]=1,且(i,j)左上角执行过操作的位置数量为偶数时,说明此时(i,j)位置就是1,需要在此处进行一次翻转操作。
- a[i][j]=0,且(i,j)左上角执行过操作的位置数量为奇数时,说明此时(i,j)位置就是1,也需要在此处进行一次翻转操作。
这个遍历过程的时间复杂度是O(n2)的,那么现在的关键问题就在于如果快速知道某个位置(i,j)的左上角有多少个位置已经进行过了操作。这时候发现i是按照从上到下遍历的,所以已经操作过的位置肯定满足行号≤i,这就转化成了一个二维偏序问题,可以用树状数组来维护有多少个列满足≤j。每次操作之前查询[1,j]上的前缀和,就可以知道(i,j)的左上方有多少个已经操作过的位置,操作完后在树状数组的j位置加上1,时间复杂度是O(log2n)的。
复杂度分析
时间复杂度
构造矩阵a的时间复杂度为O(n2)。遍历矩阵b计算答案的时间复杂度为O(n2),每次翻转操作需要对树状数组进行查询和更新,时间复杂度为O(log2n)。因此,整个算法的时间复杂度为O(n2log2n)。
空间复杂度
矩阵a的空间是O(n2)的,用于标记a[i][j]在s矩阵中是否有对应元素的st矩阵的空间也是O(n2)的(边长均为2n−1)。因此,整个算法的额外空间复杂度为O(n2)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 3005;
int T, n, a[N<<1][N<<1];
char s[N][N];
bool st[N<<1][N<<1];
class Fenwick {
public:
explicit Fenwick(int n): sums_(n + 1) {}
int lowbit(int x) {
return x&-x;
}
void add(int idx, int val) {
for(++idx; idx < sums_.size(); idx += lowbit(idx)) {
sums_[idx] += val;
}
}
int query(int idx) {
int ans = 0;
for(++idx; idx > 0; idx -= lowbit(idx)) {
ans += sums_[idx];
}
return ans;
}
int query(int left, int right) {
return query(right) - query(left - 1);
}
void init(int n) {
for(int i = 0; i <= n; i++) {
sums_[i] = 0;
}
}
private:
vector<int> sums_;
};
int main() {
scanf("%d", &T);
Fenwick tr(N<<1);
while(T--) {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%s", s[i] + 1);
}
int m = n*2 - 1;
for(int i = 1; i <= m; i++) {
for(int j = 1; j <= m; j++) {
st[i][j] = false;
}
}
int x = 0;
for(int j = n, cs = m + 1>>1; j >= 1; j--, cs--) {
++x;
int r = 1, c = j;
for(int y = cs; y <= m && r <= n && c <= n; y += 2, r++, c++) {
a[x][y] = s[r][c] - '0';
st[x][y] = true;
}
}
for(int i = 1, cs = 1; i <= n; i++, cs++) {
int r = i, c = 1;
for(int y = cs; y <= m && r <= n && c <= n; y += 2, r++, c++) {
a[x][y] = s[r][c] - '0';
st[x][y] = true;
}
x++;
}
int ans = 0;
tr.init(m + 1);
for(int i = 1; i <= m; i++) {
for(int j = 1; j <= m; j++) {
if(!st[i][j]) continue;
if(a[i][j] == 1 && tr.query(j) % 2 == 0) {
tr.add(j, 1);
ans++;
}else if(a[i][j] == 0 && tr.query(j) % 2 == 1) {
tr.add(j, 1);
ans++;
}
}
}
printf("%d\n", ans);
}
return 0;
}