题目描述
难度分:2018
输入n(1≤n≤500),m(1≤m≤105)。
你需要构造一个1~n的排列P。
然后输入m个约束,每个约束输入L、R、X(1≤L≤X≤R≤n),表示P[L],P[L+1],…,P[R]中的最大值不能是P[X]。
输出有多少个符合要求的排列P,答案可能很大,对其模998244353再输出。
输入样例1
3 2
1 3 2
1 2 1
输出样例1
1
输入样例2
5 1
1 1 1
输出样例2
0
输入样例3
10 5
3 8 4
3 10 4
1 7 2
1 8 3
3 8 7
输出样例3
1598400
输入样例4
15 17
2 11 9
2 15 13
1 14 2
5 11 5
3 15 11
1 6 2
4 15 12
3 11 6
9 13 10
2 14 6
10 15 11
1 8 6
6 14 8
2 10 2
6 12 6
3 14 12
2 6 2
输出样例4
921467228
算法
区间DP
如果x可以是区间[L,R]中的最大值的下标,那么把最大值放置在下标x上,问题变成:
- 从R−L个非最大值中,选x−L个数分到区间[L,x−1]中的方案数(其余数分到区间[x+1,R]中)。
- 区间[L,x−1]的排列数。
- 区间[x+1,R]的排列数。
三者互相独立,根据乘法原理相乘,这样就有子问题了。
状态定义
f[L][R]表示区间[L,R]的排列数。在这个定义下,答案就应该是f[1][n]。
状态转移
base case:f[i][i−1]=1。
一般情况:枚举可以是区间[L,R]的最大值的下标x,状态转移方程为f[L][R]=ΣxC(R−L,x−L)×f[L][x−1]×f[x+1][R]。其中存在组合数的计算,因此需要预处理出逆元的辅助数组fac(阶乘取模)和finv(阶乘逆元),以便O(1)计算组合数。
怎么计算哪些x可以是最大值的下标?把输入按照右端点分组,存入到一个ban数组中,ban[R]中存的是(L,X)列表。在枚举R的过程中,把左端点≥L的x标记为不能是最大值。为什么可以这样做?如果x不能是[l,r]中的最大值的下标,那么对于任意包含[l,r]的更大的区间[L,R],x也不能是[L,R]中的最大值的下标。
复杂度分析
时间复杂度
预处理出fac和finv的时间复杂度为O(n);预处理出ban的时间复杂度为O(m);状态数量是O(n2),单次转移的时间复杂度为O(n+m),动态规划的时间复杂度为O(n2(n+m))。综上,整个算法的时间复杂度为O(n+m+n2(n+m))。
空间复杂度
算法过程中,为了得到逆元所预处理出来的数组fac、finv都是O(n)的;ban数组中存储的是O(m)个约束的信息;主要的瓶颈在于DP
矩阵f,是O(n2)的。因此,整个算法的额外空间复杂度为O(m+n2)。
python 代码
MOD, N = 998244353, 500
fac = [1] * N
for i in range(1, N):
fac[i] = fac[i - 1] * i % MOD
finv = [0] * N
finv[N - 1] = pow(fac[N - 1], MOD - 2, MOD)
for i in range(N - 2, -1, -1):
finv[i] = finv[i + 1] * (i + 1) % MOD
def C(n, k):
return fac[n] * finv[k] % MOD * finv[n - k] % MOD
n, m = map(int, input().split())
ban = [[] for _ in range(n + 1)]
for _ in range(m):
l, r, x = map(int, input().split())
ban[r].append((l, x))
f = [[0] * (n + 2) for _ in range(n + 2)]
f[n + 1][n] = 1
for l in range(n, 0, -1):
b = [False] * (n + 1)
f[l][l - 1] = 1
for r in range(l, n + 1):
for bl, bx in ban[r]:
if bl >= l:
b[bx] = True
for x in range(l, r + 1):
if not b[x]:
f[l][r] = (f[l][r] + C(r - l, x - l) * f[l][x - 1] % MOD * f[x + 1][r]) % MOD
print(f[1][n])