题目描述
给你两个长度为 n
的字符串 s1
和 s2
,以及一个字符串 evil
。请你返回 好字符串 的数目。
好字符串 的定义为:它的长度为 n
,字典序大于等于 s1
,字典序小于等于 s2
,且不包含 evil
为子字符串。
由于答案可能很大,请你返回答案对 10^9 + 7
取余的结果。
样例
输入:n = 2, s1 = "aa", s2 = "da", evil = "b"
输出:51
解释:总共有 25 个以 'a' 开头的好字符串:"aa","ac","ad",...,"az"。
还有 25 个以 'c' 开头的好字符串:"ca","cc","cd",...,"cz"。
最后,还有一个以 'd' 开头的好字符串:"da"。
输入:n = 8, s1 = "leetcode", s2 = "leetgoes", evil = "leet"
输出:0
解释:所有字典序大于等于 s1 且小于等于 s2 的字符串都以 evil 字符串 "leet" 开头。所以没有好字符串。
输入:n = 2, s1 = "gx", s2 = "gz", evil = "x"
输出:2
限制
s1.length == n
s2.length == n
1 <= n <= 500
1 <= evil.length <= 50
- 所有字符串都只包含小写英文字母。
算法1
(动态规划 + KMP + 数位统计) $O(nm^3)$
- 我们可以计算出小于
s2
合法字符串的数量减去小于s1
合法字符串的数量,然后根据s2
的情况决定是否加 1。 - 假定我们要求小于
s
合法字符串的数量,我们从逐位开始统计。假设当前统计以前缀s[0...i-1]
开头的合法字符串的数量,枚举a
到s[i] - 1
可以作为当前第 $i$ 个字符,[i + 1, n - 1]
的字符无限制,统计这个情况下合法字符串的个数。我们需要预处理,$g(matched, i)$ 表示已经匹配了 $matched$ 个evil
上的字符时,长度为 $i$ 的合法字符串的个数。这一部分可以通过动态规划来预处理。 - 动态规划的状态表示为 $f(i, j)$ 表示前 $i$ 个字符,最大匹配长度为 $j$ 的方案数。
时间复杂度
- 动态规划预处理的时间复杂度为 $O(nm^3)$,其中 $m$ 为
evil
字符串的长度。 - 数位统计的时间复杂度为 $O(nm)$,故总时间复杂度为 $O(nm^3)$。
空间复杂度
- 需要额外 $O(nm)$ 的空间存储预处理的状态。
C++ 代码
class Solution {
public:
#define LL long long
static const int mod = 1000000007;
int m;
vector<vector<int>> f, g;
vector<int> p;
void init(int n, const string &evil) {
p[0] = -1;
int j = -1;
for (int i = 1; i < m; i++) {
while (j > -1 && evil[j + 1] != evil[i]) j = p[j];
if (evil[j + 1] == evil[i])
j++;
p[i] = j;
}
for (int matched = 0; matched < m; matched++) {
f[0][matched] = 1;
g[matched][0] = 1;
for (int i = 1; i < n; i++) {
for (int j = 0; j < m; j++)
f[i][j] = 0;
for (int j = 0; j < m; j++) {
int cnt = 26;
vector<bool> vis(26, false);
int k = j - 1;
while (1) {
if (!vis[evil[k + 1] - 'a']) {
cnt--;
vis[evil[k + 1] - 'a'] = true;
f[i][k + 2] = (f[i][k + 2] + f[i - 1][j]) % mod;
}
if (k == -1)
break;
k = p[k];
}
f[i][0] = (f[i][0] + (LL)(f[i - 1][j]) * cnt % mod) % mod;
}
for (int j = 0; j < m; j++)
g[matched][i] = (g[matched][i] + f[i][j]) % mod;
}
f[0][matched] = 0;
}
}
int solve(int n, const string &s, const string &evil) {
int ans = 0;
int j = -1;
for (int i = 0; i < n; i++) {
for (char c = 'a'; c < s[i]; c++) {
int k = j;
while (k > -1 && evil[k + 1] != c) k = p[k];
if (evil[k + 1] == c)
k++;
if (k == m - 1)
continue;
ans = (ans + g[k + 1][n - i - 1]) % mod;
}
while (j > -1 && evil[j + 1] != s[i])
j = p[j];
if (evil[j + 1] == s[i])
j++;
if (j == m - 1)
break;
}
return ans;
}
int findGoodStrings(int n, string s1, string s2, string evil) {
m = evil.size();
p.resize(m);
f.resize(n, vector<int>(m + 1, 0));
g.resize(m, vector<int>(n, 0));
init(n, evil);
return ((solve(n, s2, evil) - solve(n, s1, evil)
+ (int)(s2.find(evil) == string::npos)) % mod + mod) % mod;
}
};
算法2
(动态规划 + KMP) $O(nm^2)$
- 我们直接考虑用动态规划来求出答案。
- 设字符串的下标均从 1 开始。状态 $f(i, j, t)$ 表示考虑了前 $i$ 个字符,当前与
evil
的匹配长度为j
,与s1
和s2
的关系为 $t$ 时的方案数。 - 这里的 $t$ 是一个两位的二进制数字
(00, 01, 10, 11)
。00
表示后边随便添加字符,都不会影响与s1
和s2
的合法性。01
表示当前字符需要大于等于s1[i]
才能合法,10
同理,11
表示当前字符需要同时满足大于等于s1[i]
且 小于等于s2[i]
。 - 初始值 $f(0, 0, 3) = 1$,其余为 0,因为一开始
s1
和s2
都是被限制的。转移时,枚举上一次的匹配长度 $j$,枚举一个新字符c
,根据上一次匹配长度和c
,求出放置c
后的匹配长度 $k$。然后分析添加c
后新的状态 $t$。 - 这里只叙述
s1[i] < s2[i]
的情况,其余s1[i] == s2[i]
或s1[i] > s2[i]
的情况可以类似地推出。 - 在
s1[i] < s2[i]
时- 无限制
00
可以转移到无限制00
- 如果
c < s1[i]
,则10
也可以转移00
,因为此时c < s2[i]
,添加了c
则s2
就可以变得无限制。不能转移01
或11
,因为c
在这两种情况下将不合法。 - 如果
c == s1[i]
,则01
可以转移到01
,10
可以转移到00
,11
可以转移到01
。 - 如果
s1[i] < c < s2[i]
,则所有情况都可以转移到无限制。 - 如果
c == s2[i]
,则类似于c == s1[i]
。 - 如果
c > s2[i]
,则类似于c < s1[i]
。
- 无限制
- 最终结果为
\sum f(n, j, k)
,其中 $0 \le j < m$,$k = 0, 1, 2, 3$。
时间复杂度
- 状态数为 $O(nm)$,每次转移需要用
KMP
的next
函数,最坏情况下,求next
到 $0$ 需要 $O(m)$ 的时间,故总时间复杂度为 $O(nm^2)$。
空间复杂度
- 需要 $O(nm)$ 的空间存储状态。
C++ 代码
class Solution {
public:
static const int mod = 1000000007;
void add(int &x, int y) {
x = (x + y) % mod;
}
int findGoodStrings(int n, string s1, string s2, string evil) {
int m = evil.size();
vector<int> p(m + 1);
vector<int> a1(n + 1), a2(n + 1), e(m + 1);
for (int i = 1; i <= n; i++) {
a1[i] = s1[i - 1] - 'a';
a2[i] = s2[i - 1] - 'a';
}
for (int i = 1; i <= m; i++)
e[i] = evil[i - 1] - 'a';
p[1] = 0;
int j = 0;
for (int i = 2; i <= m; i++) {
while (j > 0 && e[j + 1] != e[i]) j = p[j];
if (e[j + 1] == e[i])
j++;
p[i] = j;
}
int f[510][50][4];
memset(f, 0, sizeof(f));
f[0][0][3] = 1;
for (int i = 1; i <= n; i++)
for (int j = 0; j < m; j++)
for (int c = 0; c < 26; c++) {
int k = j;
while (k > 0 && e[k + 1] != c)
k = p[k];
if (e[k + 1] == c)
k++;
add(f[i][k][0], f[i - 1][j][0]);
if (a1[i] < a2[i]) {
if (c < a1[i]) {
add(f[i][k][0], f[i - 1][j][2]);
} else if (c == a1[i]) {
add(f[i][k][1], f[i - 1][j][1]);
add(f[i][k][0], f[i - 1][j][2]);
add(f[i][k][1], f[i - 1][j][3]);
} else if (c > a1[i] && c < a2[i]) {
add(f[i][k][0], f[i - 1][j][1]);
add(f[i][k][0], f[i - 1][j][2]);
add(f[i][k][0], f[i - 1][j][3]);
} else if (c == a2[i]) {
add(f[i][k][0], f[i - 1][j][1]);
add(f[i][k][2], f[i - 1][j][2]);
add(f[i][k][2], f[i - 1][j][3]);
} else if (c > a2[i]) {
add(f[i][k][0], f[i - 1][j][1]);
}
} else if (a1[i] == a2[i]) {
if (c < a1[i]) {
add(f[i][k][0], f[i - 1][j][2]);
} else if (c == a1[i]) {
add(f[i][k][1], f[i - 1][j][1]);
add(f[i][k][2], f[i - 1][j][2]);
add(f[i][k][3], f[i - 1][j][3]);
} else if (c > a1[i]) {
add(f[i][k][0], f[i - 1][j][1]);
}
} else if (a1[i] > a2[i]) {
if (c < a2[i]) {
add(f[i][k][0], f[i - 1][j][2]);
} else if (c == a2[i]) {
add(f[i][k][2], f[i - 1][j][2]);
} else if (c == a1[i]) {
add(f[i][k][1], f[i - 1][j][1]);
} else if (c > a1[i]) {
add(f[i][k][0], f[i - 1][j][1]);
}
}
}
int ans = 0;
for (int j = 0; j < m; j++)
for (int k = 0; k < 4; k++)
add(ans, f[n][j][k]);
return ans;
}
};