题目描述
难度分:2100
输入n、k(1≤k≤n≤5×105)、d(0≤d≤109)和长为n的数组a(1≤a[i]≤109)。
你需要把a分成若干个集合,每个集合的大小至少为k,且同一集合中任意两数之差至多为d。
能否做到?输出YES
或NO
。
输入样例1
6 3 10
7 2 7 7 4 2
输出样例1
YES
输入样例2
6 2 3
4 5 3 13 4 10
输出样例2
YES
输入样例3
3 2 5
10 16 22
输出样例3
NO
算法
线段树优化DP
先贪一下,集合内任意两数之差的绝对值不超过d,那只要极差不超过d就好了。对数组排序,然后从后往前遍历,让a[i]成为某个集合的开头。
状态定义
f[i]表示后缀[i,n]能否被划分为大小至少为k(取值为1表示可以,为0表示不可以),且子段内极差≤d的若干段。在这个定义下,答案就应该是f[1]。
状态转移
初始化f[n+1]=1,这是base case。
如果要以i为集合的第一个元素,那么可以通过二分找到下一个集合的第一个元素,也就是a[j]−a[i]>d的最小j,这是下一个集合首元素的最远位置,再远就会使得当前集合[i,j)的极差超过d。而每个集合的大小至少要是k,因此这个j还可以往左移动到i+k(如果j−i<k就无解,f[i]=0)。
所以只要f[i+k:j]中存在一个值为1,f[i]=1就成立,所以状态转移方程为
f[i]=maxindex∈[i+k,j]f[index]
此时发现这是一个子数组动态求最值的操作,用线段树存储DP
数组就能在O(log2n)的时间复杂度下完成转移。
复杂度分析
时间复杂度
状态数量是O(n),单次转移的时间复杂度为O(log2n),整个算法的时间复杂度为O(nlog2n)。
空间复杂度
除开输入的原始数组a,空间消耗的瓶颈就是线段树的空间,额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 500010;
int n, k, d, a[N];
class SegmentTree {
public:
struct Info {
int l, r, v, maximum;
Info() {}
Info(int left, int right, int val): l(left), r(right), v(val), maximum(val) {}
} seg[N<<2];
explicit SegmentTree() {}
void build(int u, int l, int r) {
if(l == r) {
seg[u] = Info(l, r, -1);
}else {
int mid = l + r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid + 1, r);
pushup(u);
}
}
void modify(int pos, int val) {
modify(1, pos, val);
}
Info query(int l, int r) {
return query(1, l, r);
}
private:
void modify(int u, int pos, int val) {
if(seg[u].l == pos && seg[u].r == pos) {
seg[u] = Info(pos, pos, val);
}else {
int mid = seg[u].l + seg[u].r >> 1;
if(pos <= mid) {
modify(u<<1, pos, val);
}else {
modify(u<<1|1, pos, val);
}
pushup(u);
}
}
Info query(int u, int l, int r) {
if(l <= seg[u].l && seg[u].r <= r) return seg[u];
int mid = seg[u].l + seg[u].r >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
void pushup(int u) {
seg[u] = merge(seg[u<<1], seg[u<<1|1]);
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l;
info.r = rchild.r;
info.v = lchild.v + rchild.v;
info.maximum = max(lchild.maximum, rchild.maximum);
return info;
}
};
int main() {
scanf("%d%d%d", &n, &k, &d);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
sort(a + 1, a + n + 1);
SegmentTree seg;
seg.build(1, 1, n + 1);
seg.modify(n + 1, 1);
for(int i = n; i >= 1; i--) {
int j = upper_bound(a + 1, a + n + 1, a[i] + d) - a;
int cnt = j - i;
if(cnt < k) {
seg.modify(i, 0);
}else {
int res = seg.query(i + k, j).maximum;
if(res > 0) {
seg.modify(i, 1);
}else {
seg.modify(i, 0);
}
}
}
puts(seg.query(1, 1).v == 1? "YES": "NO");
return 0;
}