题目描述
难度分:2600
输入n(1≤n≤3×105)和m(0≤m≤min(20,n×(n−1)2)。
解释:有n个雇佣兵,从中选择一些人(至少一个人),组成一支部队。m的含义见下。
然后输入n个闭区间的左右端点[Li,Ri],范围[1,n]。
解释:如果选了第i位雇佣兵,那么部队的人数必须在闭区间[Li,Ri]中。
最后输入m对数字(ai,bi),满足1≤ai<bi≤n。
解释:这m对雇佣兵相互憎恨,如果选了第ai位雇佣兵,那么不能选第bi位雇佣兵,反之亦然。
输出有多少种选法,模998244353。
输入样例1
3 0
1 1
2 3
1 3
输出样例1
3
输入样例2
3 1
1 1
2 3
1 3
2 3
输出样例2
2
算法
容斥原理
先考虑m=0的情况,枚举部队人数i=1,2,…,n,对于固定的i,有多少种选法?
我们需要知道有多少个闭区间包含i。对于每个区间[L,R],把[L,R]内的数都加一。这可以用差分数组解决。计算差分数组的前缀和,得到cnt数组,也就是有cnt[i]个区间包含i,那么选法就是Cicnt[i]。
所以m=0的时候,一共有Σi∈[1,n]Cicnt[i]种选法。
回到原问题,考虑对立事件,从不考虑m对约束的方案数中,减去不合法的方案数(选了m对约束中的)。
这可以用子集容斥计算:
枚举m对约束的子集,假设子集中有j个人,相当于这j个人一定要选,那么选法就是Σi∈[p,q]Ci−jcnt[i]−j其中p和q是这j个人对应区间的交集的左右端点。
由于m至多20,所以j至多40。对0~40这41个不同的j,预处理出Ci−jcnt[i]−j的前缀和。
复杂度分析
时间复杂度
预处理出出Ci−jcnt[i]−j的前缀和需要遍历2nm次,时间复杂度为O(nm)。子集容斥需要二进制枚举i∈[0,2m)的状态,对每个状态需要枚举子集中的人数j∈[1,m],因此时间复杂度为O(m2m)。整个算法的时间复杂度为O(nm+m2m)。
注意下面的代码在二进制枚举子集时为了防止Codeforces卡哈希而使用了一个有序集合,从而使得实际的时间复杂度多了一个log2n,实际上可以通过重写哈希函数达到理论的时间复杂度。
空间复杂度
空间的瓶颈主要在于Ci−jcnt[i]−j的前缀和数组s,空间消耗为O(nm),其余数组都是线性空间。因此,整个算法的额外空间复杂度为O(nm)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 300010, MOD = 998244353;
int n, m, ans, l[N], r[N], a[N], b[N], inv[N], f[N], s[N][50];
// 快速幂
LL fast_power(int a, int k) {
int res = 1 % MOD;
while(k) {
if(k & 1) res = (LL)res * a % MOD;
a = (LL)a * a % MOD;
k >>= 1;
}
return res;
}
// 组合数
LL C(LL x, LL y) {
if(x < y || x < 0 || y < 0) return 0;
return 1ll * f[x] * inv[x - y] % MOD * inv[y] % MOD;
}
int main() {
scanf("%d%d", &n, &m);
f[0] = inv[0] = 1;
for(int i = 1; i <= n; i++) {
f[i] = 1ll * f[i - 1] * i % MOD;
inv[i] = fast_power(f[i], MOD - 2);
}
for(int i = 1; i <= n; i++) {
scanf("%d%d", &l[i], &r[i]);
}
for(int i = 1; i <= n; i++) {
s[l[i]][0]++;
s[r[i] + 1][0]--;
}
for(int i = 1; i <= n; i++) {
s[i][0] += s[i - 1][0];
}
for(int i = 1; i <= n; i++) {
for(int j = m<<1; j >= 0; j--) {
s[i][j] = (s[i - 1][j] + C(s[i][0] - j, i - j)) % MOD;
}
}
for(int i = 1; i <= m; i++) {
scanf("%d%d", &a[i], &b[i]);
}
int ans = 0;
for(int i = 0; i < (1<<m); i++){
set<int> st;
int L = 1, R = n;
for(int j = 1; j <= m; j++) {
if(i>>j - 1&1){
L = max(L, l[a[j]]);
L = max(L, l[b[j]]);
R = min(R, r[a[j]]);
R = min(R, r[b[j]]);
st.insert(a[j]);
st.insert(b[j]);
}
}
if(L > R) continue;
int cnt = st.size();
if(__builtin_popcount(i)&1) {
ans = ((ans - s[R][cnt] + s[L - 1][cnt]) % MOD + MOD) % MOD;
}else {
ans = ((ans + s[R][cnt] - s[L - 1][cnt]) % MOD + MOD) % MOD;
}
}
printf("%d\n", ans);
return 0;
}