题目描述
难度分:2500
输入n(1≤n≤5×105),k(1≤k≤30),m(0≤m≤5×105)。
然后输入m行,每行输入三个数L,R(1≤L≤R≤n),x(0≤x<2k),表示一个约束:
a[L]AND
a[L+1]AND
…AND
a[R]=x
其中AND
表示按位与。
输出有多少个长为n的数组a,满足上述这m个约束,且元素值在[0,2k)内。注意数组下标从1开始。
答案模998244353。
输入样例1
4 3 2
1 3 3
3 4 6
输出样例1
3
输入样例2
5 2 3
1 3 2
2 5 0
3 3 3
输出样例2
33
算法
今天在做这道题的时候不止一次感觉快要开出来了,但最终以失败告终,还是看了灵佬的题解。看来实力还远不够把一周的茶全部都做出来,需要继续努力。
动态规划+双指针
这个题比较容易看出来的一点就是各位都是独立的,按位来计算方案数,然后把每一位的方案数乘起来就是最终答案。对于某个特定的二进制位,我们用动态规划来计算它的合法方案数。
我们先将所有的操作(l,r,x)存入数组d,并遍历x的位。如果x在某一位为1,说明子数组[l,r]中的所有元素在这一位都应该是1,在[l,r]区间上加1,用差分数组a来记录;如果这一位是0,那子数组[l,r]中的元素就有可能在这一位是0(至少有一个0)。
求完差分之后再原地求前缀和,这样a[i]就是要求数组第i个元素在当前二进制位为1的条件有多少个(一共有m个条件)。
状态定义
f[i]表示考虑前i个数,第i个数在当前二进制位填0的方案数(因为方案数就取决于在什么位置可以填0,这个状态定义我这种菜鸡是真想不出来)。
规定f[0]=1,表示[1,i−1]在当前位全部是1,有一种方案。
状态转移
如果下标i必须填1,则f[i]=0;否则枚举上一个0的位置j(即子数组(j,i)内所有元素在当前位都是1)。设最小的j为minJ则有状态转移方程f[i]=f[minJ]+f[minJ+1]+…+f[i−1],接下来考虑i的转移来源minJ。
- 如果i−1恰好是某个约束区间[l,r]的右端点,且这个区间内至少要有一个0,那么转移来源不能小于l,否则[l,r]就全部是1了。
- 如果i−1不是某个约束区间[l,r]的右端点,那么f[i]对应的minJ和f[i−1]对应的minJ相同。
预处理出来一个maxL[r]数组,表示至少要有一个0的约束区间右端点为r时,左端点l的最大值。接下来我们只要能够快速得到f[minJ]+f[minJ+1]+…+f[i−1]的值就行了,注意到minJ是不会回退的,所以用双指针做滑动窗口的优化,用一个变量sum来维护这个和即可。
复杂度分析
时间复杂度
一共有k位,每位里仅存在有限几次O(n)级别的算法(差分数组预处理、前缀和、动态规划+双指针),因此算法整体的时间复杂度为O(kn)(这里认为m和n一个级别,都是O(n))。
空间复杂度
差分数组/前缀和数组a、DP
数组f,以及操作数组d都是O(n)级别的空间消耗,因此额外空间复杂度为O(n)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int N = 500010, MOD = 998244353;
int n, k, m, a[N], f[N];
vector<PII> d[N];
int main() {
scanf("%d%d%d", &n, &k, &m);
for(int i = 1; i <= n; i++) {
d[i].clear();
}
for(int i = 1; i <= m; i++) {
int l, r, x;
scanf("%d%d%d", &l, &r, &x);
d[r].push_back({l, x});
}
int ans = 1;
for(int i = 0; i < k; i++) {
for(int j = 1; j <= n; j++) {
f[j] = a[j] = 0;
}
f[0] = 1;
for(int r = 1; r <= n; r++) {
for(auto&pir: d[r]) {
int l = pir.first, x = pir.second;
if(x>>i&1) {
// [l,r]必须全1
++a[l], --a[r + 1];
}
}
}
for(int j = 1; j <= n; j++) {
a[j] += a[j - 1];
}
int p = 0, maxL = 0, sum = 1;
for(int r = 1; r <= n; r++) {
while(p < maxL) {
sum = (sum - f[p++] + MOD)%MOD;
}
if(!a[r]) {
f[r] = sum;
sum = (sum + f[r]) % MOD;
}
for(auto&pir: d[r]) {
int l = pir.first, x = pir.second;
if(!(x>>i&1)) {
maxL = max(maxL, l);
}
}
}
int temp = 0;
for(int j = maxL; j <= n; j++) {
temp = (temp + f[j]) % MOD;
}
ans = (LL)ans*temp%MOD;
}
printf("%d\n", ans);
return 0;
}