题目描述
难度分:2200
输入n、m(1≤n≤m≤2×105),长为n的字符串s,长为m的字符串t,只包含大写英文字母。
你可以随意重排s和t中的字母。然后从t中选一个长为n的子序列t′,使得对于每个s[i],要么s[i]=t′[i],要么s[i]+1=t′[i]。比如s=AAB
,t′=ABB
是合法的。
有多少个不同的(s,t′)二元组?模998244353。
输入样例1
3 4
AMA
ANAB
输出样例1
9
输入样例2
5 8
BINUS
BINANUSA
输出样例2
120
输入样例3
15 30
BINUSUNIVERSITY
BINANUSANTARAUNIVERSITYJAKARTA
输出样例3
151362308
输入样例4
4 4
UDIN
ASEP
输出样例4
0
算法
动态规划
这个题比较容易出的一个思路就是先固定s串不动,用t串中的字母来匹配s串中的字母。
状态定义
f[i][j]表示s串中位于字母表中排名第i的字母在消耗t串中位于字母表中排名第i+1的字母j个的情况下,能够产生多少种方案。在这个定义下,如果顺序考虑字母i∈[1,26],最后一个字母z
就不存在下一个字母,f[26][0]就是答案,再乘上s串的排列数目n!Πi∈[1,26]cnta[i]就是最终答案,其中cnta[i]是s串中字母表排名第i的字母的频数,类似的定义一个cntb数组,cntb[i]是t串中字母表排名第i的字母的频数。
状态转移
对于字母i,枚举需要t串中的j个i+1来与之匹配。显然j最小就是cnta[i]−cntb[i],即t串中所有的i都被用来匹配s串中的i,只有剩下的cnta[i]−cntb[i]个i需要用t串中的i+1来匹配;而j最多就是min(cnta[i],cntb[i+1]),即s串中的所有i都用t串中的i+1来匹配,但是又不能够超过t串中i+1的总数cntb[i+1]。
因此得到状态转移方程为
f[i][j]=Σcntb[i]−cnta[i]+jk=0f[i−1][k]×Cjcnta[i]
组合数Cjcnta[i]表示s串中要选j个出来和t串中的i+1字母匹配,方案数有Cjcnta[i]个。而f[i−1][k]中就还需要枚举k,k从0开始,最多可以取到cntb[i]−(cnta[i]−j),即cntb[i]中要有cnta[i]−j个与s串中的i匹配,不能超过这个数。
这样就有O(26n)的状态数量,而状态转移也是O(n)的,肯定会超时。但是注意到Σcntb[i]−cnta[i]+jk=0f[i−1][k]是上一行状态的前缀和,所以用前缀和优化就可以O(1)转移了。
复杂度分析
时间复杂度
预处理出s串的字母频数表时间复杂度为O(n),预处理出t串的字母频数表时间复杂度为O(m)。动态规划的状态数量是O(n)级别的,单次转移的时间复杂度为O(1),动态规划的时间复杂度为O(n)。因此,整个算法的时间复杂度为O(n+m)。
空间复杂度
DP
数组f的空间消耗为O(n),其前缀和数组sum也是这个规模。为了快速计算组合数(组合数是基于O(n)长度的字符串s),需要用到逆元,预处理出阶乘余数表及其对应逆元,空间消耗为O(n)。所以,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010, MOD = 998244353;
int n, m;
char s[N], t[N];
LL inv[N], finv[N], fac[N], f[30][N], sum[30][N];
// 预处理逆元
void get_inv(int n) {
inv[0] = inv[1] = 1;
for(int i = 2; i <= n; i++) {
inv[i] = (MOD - MOD/i) * inv[MOD % i] % MOD;
}
finv[0] = finv[1] = fac[0] = fac[1] = 1;
for(int i = 2; i <= n; i++) {
fac[i] = fac[i - 1] * i % MOD;
finv[i] = finv[i - 1] * inv[i] % MOD;
}
}
// 排列数
LL A(LL n, LL m) {
if(n == 0 || m == 0) return 1;
return fac[n] * finv[n - m] % MOD;
}
// 组合数
LL C(LL n, LL m) {
if(m == 0) return 1;
if(m < 0 || m > n) return 0;
return A(n, m) * finv[m] % MOD;
}
int main() {
scanf("%d%d", &n, &m);
scanf("%s", s + 1);
scanf("%s", t + 1);
get_inv(m);
int cnta[30] = {0}, cntb[30] = {0};
for(int i = 1; i <= n; i++) {
cnta[s[i] - 'A' + 1]++;
}
for(int i = 1; i <= m; i++) {
cntb[t[i] - 'A' + 1]++;
}
f[0][0] = 1;
for(int i = 1; i <= 26; i++) {
sum[i - 1][0] = f[i - 1][0];
for(int j = 1; j <= cntb[i]; j++) {
sum[i - 1][j] = (sum[i - 1][j - 1] + f[i - 1][j]) % MOD;
}
for(int j = max(0, cnta[i] - cntb[i]); j <= min(cnta[i], cntb[i + 1]); j++) {
f[i][j] = sum[i - 1][cntb[i] - (cnta[i] - j)] * C(cnta[i], j) % MOD;
}
}
LL ans = f[26][0];
ans = ans * fac[n] % MOD;
for(int i = 1; i <= 26; i++) {
ans = ans * finv[cnta[i]] % MOD;
}
printf("%d\n", ans);
return 0;
}