AcWing
  • 首页
  • 课程
  • 题库
  • 更多
    • 竞赛
    • 题解
    • 分享
    • 问答
    • 应用
    • 校园
  • 关闭
    历史记录
    清除记录
    猜你想搜
    AcWing热点
  • App
  • 登录/注册

AtCoder ARC183C. Not Argmax    原题链接    困难

作者: 作者的头像   pein531 ,  2025-04-21 18:01:53 · 北京 ,  所有人可见 ,  阅读 2


0


题目描述

难度分:$2018$

输入$n(1 \leq n \leq 500)$,$m(1 \leq m \leq 10^5)$。

你需要构造一个$1$~$n$的排列$P$。

然后输入$m$个约束,每个约束输入$L$、$R$、$X(1 \leq L \leq X \leq R \leq n)$,表示$P[L],P[L+1],…,P[R]$中的最大值不能是$P[X]$。

输出有多少个符合要求的排列$P$,答案可能很大,对其模$998244353$再输出。

输入样例$1$

3 2
1 3 2
1 2 1

输出样例$1$

1

输入样例$2$

5 1
1 1 1

输出样例$2$

0

输入样例$3$

10 5
3 8 4
3 10 4
1 7 2
1 8 3
3 8 7

输出样例$3$

1598400

输入样例$4$

15 17
2 11 9
2 15 13
1 14 2
5 11 5
3 15 11
1 6 2
4 15 12
3 11 6
9 13 10
2 14 6
10 15 11
1 8 6
6 14 8
2 10 2
6 12 6
3 14 12
2 6 2

输出样例$4$

921467228

算法

区间DP

如果$x$可以是区间$[L,R]$中的最大值的下标,那么把最大值放置在下标$x$上,问题变成:

  1. 从$R-L$个非最大值中,选$x-L$个数分到区间$[L,x-1]$中的方案数(其余数分到区间$[x+1,R]$中)。
  2. 区间$[L,x-1]$的排列数。
  3. 区间$[x+1,R]$的排列数。

三者互相独立,根据乘法原理相乘,这样就有子问题了。

状态定义

$f[L][R]$表示区间$[L,R]$的排列数。在这个定义下,答案就应该是$f[1][n]$。

状态转移

$base$ $case$:$f[i][i-1]=1$。

一般情况:枚举可以是区间$[L,R]$的最大值的下标$x$,状态转移方程为$f[L][R]=\Sigma_{x}C(R-L,x-L) \times f[L][x-1] \times f[x+1][R]$。其中存在组合数的计算,因此需要预处理出逆元的辅助数组$fac$(阶乘取模)和$finv$(阶乘逆元),以便$O(1)$计算组合数。

怎么计算哪些$x$可以是最大值的下标?把输入按照右端点分组,存入到一个$ban$数组中,$ban[R]$中存的是$(L,X)$列表。在枚举$R$的过程中,把左端点$\geq L$的$x$标记为不能是最大值。为什么可以这样做?如果$x$不能是$[l,r]$中的最大值的下标,那么对于任意包含$[l,r]$的更大的区间$[L,R]$,$x$也不能是$[L,R]$中的最大值的下标。

复杂度分析

时间复杂度

预处理出$fac$和$finv$的时间复杂度为$O(n)$;预处理出$ban$的时间复杂度为$O(m)$;状态数量是$O(n^2)$,单次转移的时间复杂度为$O(n+m)$,动态规划的时间复杂度为$O(n^2(n+m))$。综上,整个算法的时间复杂度为$O(n+m+n^2(n+m))$。

空间复杂度

算法过程中,为了得到逆元所预处理出来的数组$fac$、$finv$都是$O(n)$的;$ban$数组中存储的是$O(m)$个约束的信息;主要的瓶颈在于DP矩阵$f$,是$O(n^2)$的。因此,整个算法的额外空间复杂度为$O(m+n^2)$。

python 代码

MOD, N = 998244353, 500
fac = [1] * N
for i in range(1, N):
    fac[i] = fac[i - 1] * i % MOD
finv = [0] * N
finv[N - 1] = pow(fac[N - 1], MOD - 2, MOD)
for i in range(N - 2, -1, -1):
    finv[i] = finv[i + 1] * (i + 1) % MOD

def C(n, k):
    return fac[n] * finv[k] % MOD * finv[n - k] % MOD

n, m = map(int, input().split())
ban = [[] for _ in range(n + 1)]
for _ in range(m):
    l, r, x = map(int, input().split())
    ban[r].append((l, x))
f = [[0] * (n + 2) for _ in range(n + 2)]
f[n + 1][n] = 1
for l in range(n, 0, -1):
    b = [False] * (n + 1)
    f[l][l - 1] = 1
    for r in range(l, n + 1):
        for bl, bx in ban[r]:
            if bl >= l:
                b[bx] = True
        for x in range(l, r + 1):
            if not b[x]:
                f[l][r] = (f[l][r] + C(r - l, x - l) * f[l][x - 1] % MOD * f[x + 1][r]) % MOD
print(f[1][n])

0 评论

App 内打开
你确定删除吗?
1024
x

© 2018-2025 AcWing 版权所有  |  京ICP备2021015969号-2
用户协议  |  隐私政策  |  常见问题  |  联系我们
AcWing
请输入登录信息
更多登录方式: 微信图标 qq图标 qq图标
请输入绑定的邮箱地址
请输入注册信息