题目描述
难度分:1900
输入T(≤105)表示T组数据。所有数据的n之和≤5×105。
每组数据输入n(1≤n≤5×105)和长为n的数组a(0≤a[i]≤n)。
称序列b为MEX
序列,如果对所有i都有|b[i]−mex(b[0],…,b[i])|≤1成立,其中mex(S)表示不在S中的最小非负整数。
输出a的非空MEX
子序列的个数,模998244353。
两个子序列只要有元素下标不同,就算不同的子序列。例如a=[0,0,0]有7个不同的非空子序列。
注:子序列不要求连续。
输入样例
4
3
0 2 1
2
1 0
5
0 0 0 0 0
4
0 1 2 3
输出样例
4
2
31
7
算法
动态规划
状态定义
dp[i][j][0]表示考虑到a[i],MEX
=j且序列中没有j+1的方案数。dp[i][j][1]表示考虑到a[i],MEX
=j且序列中有j+1的方案数。
状态转移
假设a[i]=j,首先可以确定的是:所有dp[i][j][0]都可以由dp[i−1][j][0]转移过来;所有dp[i][j][1]都可以由dp[i−1][j][1]转移过来,即不选择j这个数加入子序列。状态转移方程为dp[i][j][0]=dp[i−1][j][0],dp[i][j][0]=dp[i−1][j][0]。
如果选a[j]=j这个数加入子序列:
- dp[i][j+1][0]可以由dp[i−1][j][0]转移过来,即一个
MEX
=j的子序列加入j后MEX
=j+1。还可以从dp[i−1][j+1][0]转移过来,即一个MEX
=j+1的子序列加入j后没有任何改变。还可以由dp[i][j+1][0]转移过来,即一个MEX
=j+1且不存在j+2的子序列加入j后,仍然满足MEX
=j+1(加不加j没有影响)。 - dp[i][j+1][1]可以由dp[i−1][j+1][1]转移过来,即一个
MEX
=j+1且存在j+2的子序列加入j后,仍然满足MEX
=j+1。 - dp[i][j−1][1]可以由dp[i−1][j−1][1]转移过来,即一个
MEX
=j−1,且存在j的子序列加入j后也不会有什么改变,仍然满足MEX
=j−1。还可以由dp[i−1][j−1][0]转移过来,即一个MEX
=j−1,且不存在j的子序列加入j后仍然满足MEX
=j−1。
因此综合以上两种情况,状态转移方程为
dp[i][j+1][0]=dp[i−1][j+1][0]+dp[i−1][j][0]+dp[i−1][j+1][0]
dp[i][j+1][1]=dp[i−1][j+1][1]+dp[i−1][j+1][1]
dp[i][j−1][0]=dp[i−1][j−1][0]
dp[i][j−1][1]=dp[i][j−1][1]+dp[i][j−1][1]+dp[i][j−1][0]
最后,发现第一维i是可以不要的,因此不需要三维DP
,只需要保留后两维即可,空间是O(n)的。
复杂度分析
时间复杂度
状态数量是O(n)级别,单次转移的时间复杂度为O(1),因此算法整体的时间复杂度为O(n)。
空间复杂度
开辟了n×2的DP
数组,额外空间复杂度为O(n)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 500010, MOD = 998244353;
int t, n, a[N];
LL dp[N][2];
int solve() {
for(int i = 0; i <= n + 2; i++) {
dp[i][0] = dp[i][1] = 0;
}
dp[0][0] = 1; // base case
for(int i = 1; i <= n; i++) {
dp[a[i] + 1][0] = (dp[a[i] + 1][0] + dp[a[i]][0] + dp[a[i] + 1][0]) % MOD;
dp[a[i] + 1][1] = (dp[a[i] + 1][1] + dp[a[i] + 1][1]) % MOD;
if(a[i] >= 1) {
dp[a[i] - 1][1] = (dp[a[i] - 1][1] + dp[a[i] - 1][1] + dp[a[i] - 1][0]) % MOD;
}
}
LL ans = 0;
for(int i = 0; i <= n + 1; i++) {
ans = (ans + dp[i][0]) % MOD;
ans = (ans + dp[i][1]) % MOD;
}
ans = (ans - 1 + MOD) % MOD; // 减去一个空序列
return ans;
}
int main() {
scanf("%d", &t);
while(t--) {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
printf("%d\n", solve());
}
return 0;
}