题目描述
难度分:1381
输入n、k(2≤k≤n≤1000且k≤10)和长为n的字符串s,只包含A
、B
、?
三种字符。
如果一个字符串不包含长度恰好为k的回文子串,我们称其为合法字符串。
把s中的所有?
替换成A
或B
,可以得到2q个不同的字符串,其中 q是s中的?
的个数。
输出这2q个字符串中的合法字符串个数,模998244353。
输入样例1
7 4
AB?A?BA
输出样例1
1
输入样例2
40 7
????????????????????????????????????????
输出样例2
116295436
输入样例3
15 5
ABABA??????????
输出样例3
0
输入样例4
40 8
?A?B??B?B?AA?A?B??B?A???B?BB?B???BA??BAA
输出样例4
259240
算法
这样的计数题目一般不是DP
就是组合数学,但是看到数据范围这么小,可以往DP
上考虑。
状压DP
。
假设A
是二进制0,B
是二进制1。
状态定义
f[i][mask]表示当前要考虑第i个位置,前缀[1,i)已经填完,最后k−1个字母是mask的情况下(mask的高位表示排在后面的字母,低位表示排在前面的字母),能够得到多少个合法串。
状态转移
分为以下两种情况:
- 如果s[i]≠
A
,那么当前位置就要填B
,应该把原来[1,i−1]的mask状态去掉最低位(因为最低位是最远的,用右移就可以实现),然后在最高位或上一个1,当i<k长度不够或mask+2k−1不是回文串时进行状态转移,状态转移方程为f[i][⌊mask2⌋+2k−2]+=f[i−1][mask]。 - 如果s[i]≠
B
,那么当前位置就要填A
,应该把原来[1,i−1]的mask状态去掉最低位(因为最低位是最远的,用右移就可以实现),然后在最高位或上一个0,相当于什么也没干。当i<k长度不够或mask不是回文串时进行状态转移,状态转移方程为f[i][⌊mask2⌋]+=f[i−1][mask]。
最后的答案就是Σmask∈[0,2k−1)f[n][mask],遍历一下n长度的串最后k−1个字母的状态,累加起来即可。
复杂度分析
时间复杂度
遍历i∈[1,n]时间复杂度为O(n),遍历mask时间复杂度为O(2k),检查mask是不是长度为k的回文时间复杂度为O(k)。单次转移的时间复杂度为O(1),因此算法整体的时间复杂度为O(nk2k)。
空间复杂度
空间瓶颈在于DP
矩阵f,状态数量为O(n2k),这也是整个算法的额外空间复杂度。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1001, K = 11, MOD = 998244353;
int n, k, f[N][1<<K];
char s[N];
bool ispalindrome(int mask) {
int l = 0, r = k - 1;
while(l < r) {
if((mask>>l&1)^(mask>>r&1)) return false;
l++, r--;
}
return true;
}
int main() {
scanf("%d%d", &n, &k);
scanf("%s", s + 1);
// 初始化第1个位置的状态
if(s[1] != 'A') f[1][1<<k - 2] = 1;
if(s[1] != 'B') f[1][0] = 1;
// 递推
for(int i = 2; i <= n; i++) {
for(int mask = 0; mask < (1<<k - 1); mask++) {
if(s[i] != 'A' && (i < k || !ispalindrome(mask|(1<<k - 1)))) {
// i位置是B
f[i][mask>>1|(1<<k - 2)] = (f[i][mask>>1|(1<<k - 2)] + f[i - 1][mask]) % MOD;
}
if(s[i] != 'B' && (i < k || !ispalindrome(mask))) {
// i位置是A
f[i][mask>>1] = (f[i][mask>>1] + f[i - 1][mask]) % MOD;
}
}
}
int ans = 0;
for(int mask = 0; mask < (1<<k - 1); mask++) {
ans = (ans + f[n][mask]) % MOD;
}
printf("%d\n", ans);
return 0;
}