题目描述
难度分:$2018$
输入$n(1 \leq n \leq 500)$,$m(1 \leq m \leq 10^5)$。
你需要构造一个$1$~$n$的排列$P$。
然后输入$m$个约束,每个约束输入$L$、$R$、$X(1 \leq L \leq X \leq R \leq n)$,表示$P[L],P[L+1],…,P[R]$中的最大值不能是$P[X]$。
输出有多少个符合要求的排列$P$,答案可能很大,对其模$998244353$再输出。
输入样例$1$
3 2
1 3 2
1 2 1
输出样例$1$
1
输入样例$2$
5 1
1 1 1
输出样例$2$
0
输入样例$3$
10 5
3 8 4
3 10 4
1 7 2
1 8 3
3 8 7
输出样例$3$
1598400
输入样例$4$
15 17
2 11 9
2 15 13
1 14 2
5 11 5
3 15 11
1 6 2
4 15 12
3 11 6
9 13 10
2 14 6
10 15 11
1 8 6
6 14 8
2 10 2
6 12 6
3 14 12
2 6 2
输出样例$4$
921467228
算法
区间DP
如果$x$可以是区间$[L,R]$中的最大值的下标,那么把最大值放置在下标$x$上,问题变成:
- 从$R-L$个非最大值中,选$x-L$个数分到区间$[L,x-1]$中的方案数(其余数分到区间$[x+1,R]$中)。
- 区间$[L,x-1]$的排列数。
- 区间$[x+1,R]$的排列数。
三者互相独立,根据乘法原理相乘,这样就有子问题了。
状态定义
$f[L][R]$表示区间$[L,R]$的排列数。在这个定义下,答案就应该是$f[1][n]$。
状态转移
$base$ $case$:$f[i][i-1]=1$。
一般情况:枚举可以是区间$[L,R]$的最大值的下标$x$,状态转移方程为$f[L][R]=\Sigma_{x}C(R-L,x-L) \times f[L][x-1] \times f[x+1][R]$。其中存在组合数的计算,因此需要预处理出逆元的辅助数组$fac$(阶乘取模)和$finv$(阶乘逆元),以便$O(1)$计算组合数。
怎么计算哪些$x$可以是最大值的下标?把输入按照右端点分组,存入到一个$ban$数组中,$ban[R]$中存的是$(L,X)$列表。在枚举$R$的过程中,把左端点$\geq L$的$x$标记为不能是最大值。为什么可以这样做?如果$x$不能是$[l,r]$中的最大值的下标,那么对于任意包含$[l,r]$的更大的区间$[L,R]$,$x$也不能是$[L,R]$中的最大值的下标。
复杂度分析
时间复杂度
预处理出$fac$和$finv$的时间复杂度为$O(n)$;预处理出$ban$的时间复杂度为$O(m)$;状态数量是$O(n^2)$,单次转移的时间复杂度为$O(n+m)$,动态规划的时间复杂度为$O(n^2(n+m))$。综上,整个算法的时间复杂度为$O(n+m+n^2(n+m))$。
空间复杂度
算法过程中,为了得到逆元所预处理出来的数组$fac$、$finv$都是$O(n)$的;$ban$数组中存储的是$O(m)$个约束的信息;主要的瓶颈在于DP
矩阵$f$,是$O(n^2)$的。因此,整个算法的额外空间复杂度为$O(m+n^2)$。
python 代码
MOD, N = 998244353, 500
fac = [1] * N
for i in range(1, N):
fac[i] = fac[i - 1] * i % MOD
finv = [0] * N
finv[N - 1] = pow(fac[N - 1], MOD - 2, MOD)
for i in range(N - 2, -1, -1):
finv[i] = finv[i + 1] * (i + 1) % MOD
def C(n, k):
return fac[n] * finv[k] % MOD * finv[n - k] % MOD
n, m = map(int, input().split())
ban = [[] for _ in range(n + 1)]
for _ in range(m):
l, r, x = map(int, input().split())
ban[r].append((l, x))
f = [[0] * (n + 2) for _ in range(n + 2)]
f[n + 1][n] = 1
for l in range(n, 0, -1):
b = [False] * (n + 1)
f[l][l - 1] = 1
for r in range(l, n + 1):
for bl, bx in ban[r]:
if bl >= l:
b[bx] = True
for x in range(l, r + 1):
if not b[x]:
f[l][r] = (f[l][r] + C(r - l, x - l) * f[l][x - 1] % MOD * f[x + 1][r]) % MOD
print(f[1][n])