题目描述
难度分:1600
输入n(1≤n≤2×105),m(1≤m≤106),k(1≤k≤n)和长为n的数组a(1≤a[i]≤106)。
数轴上有n个点,要求任意长为m−1的闭区间不能包含≥k个点。问:至少要去掉多少个点?
注:闭区间[L,R]的长度为R−L。
输入样例1
3 3 2
3 5 1
输出样例1
1
输入样例2
5 10 3
12 8 18 25 1
输出样例2
0
输入样例3
7 7 2
7 3 4 1 6 5 2
输出样例3
6
输入样例4
2 2 2
1 3
输出样例4
0
算法
贪心+平衡树
这个贪心比较经典,先将a数组排序,然后从前往后贪。对于某个a[l],利用二分查找找到a[r]>a[l]+m−1的第一个位置,这样以a[l]为左端点,长度为m−1的时间段就有r−l个点。如果r−l≥k,就开始删点,为了让删除的点能够最大限度地惠泽后面的区间,贪心删除最后的r−l−k+1个点即可,然后l自增继续这个过程。
此时发现在这个过程中,对于某个位置l,[l,r)内可能已经删除过一些点了,我们需要维护这个信息,才能确定当前需要删除的点。一种无脑的做法就是用一个有序的数据结构来维护剩余的时间点,平衡树就是一个很好的选择,初始化平衡树timeline,存储a中的所有时间点,以上的删点过程都在平衡树上操作。需要注意的是,在区间[timeline[l],timeline[l]+m−1]里面的点全部都要被删除的情况下,l不用自增,因为后面的时间点会直接递补上来。
复杂度分析
时间复杂度
对a排序的时间复杂度为O(nlog2n)。接下来贪心需要遍历l,当发现r−l≥k时需要删点,在最差情况下每个点都要被删除一次,时间复杂度还是O(nlog2n)。因此,整个算法的时间复杂度为O(nlog2n)。
空间复杂度
除开输入的a数组,空间消耗就在于平衡树timeline,里面需要存n个时间点,额外空间复杂度为O(n)。
python代码
class SortedList:
def __init__(self, iterable=[], _load=200):
"""Initialize sorted list instance."""
values = sorted(iterable)
self._len = _len = len(values)
self._load = _load
self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
self._list_lens = [len(_list) for _list in _lists]
self._mins = [_list[0] for _list in _lists]
self._fen_tree = []
self._rebuild = True
def _fen_build(self):
"""Build a fenwick tree instance."""
self._fen_tree[:] = self._list_lens
_fen_tree = self._fen_tree
for i in range(len(_fen_tree)):
if i | i + 1 < len(_fen_tree):
_fen_tree[i | i + 1] += _fen_tree[i]
self._rebuild = False
def _fen_update(self, index, value):
"""Update `fen_tree[index] += value`."""
if not self._rebuild:
_fen_tree = self._fen_tree
while index < len(_fen_tree):
_fen_tree[index] += value
index |= index + 1
def _fen_query(self, end):
"""Return `sum(_fen_tree[:end])`."""
if self._rebuild:
self._fen_build()
_fen_tree = self._fen_tree
x = 0
while end:
x += _fen_tree[end - 1]
end &= end - 1
return x
def _fen_findkth(self, k):
"""Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
_list_lens = self._list_lens
if k < _list_lens[0]:
return 0, k
if k >= self._len - _list_lens[-1]:
return len(_list_lens) - 1, k + _list_lens[-1] - self._len
if self._rebuild:
self._fen_build()
_fen_tree = self._fen_tree
idx = -1
for d in reversed(range(len(_fen_tree).bit_length())):
right_idx = idx + (1 << d)
if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
idx = right_idx
k -= _fen_tree[idx]
return idx + 1, k
def _delete(self, pos, idx):
"""Delete value at the given `(pos, idx)`."""
_lists = self._lists
_mins = self._mins
_list_lens = self._list_lens
self._len -= 1
self._fen_update(pos, -1)
del _lists[pos][idx]
_list_lens[pos] -= 1
if _list_lens[pos]:
_mins[pos] = _lists[pos][0]
else:
del _lists[pos]
del _list_lens[pos]
del _mins[pos]
self._rebuild = True
def _loc_left(self, value):
"""Return an index pair that corresponds to the first position of `value` in the sorted list."""
if not self._len:
return 0, 0
_lists = self._lists
_mins = self._mins
lo, pos = -1, len(_lists) - 1
while lo + 1 < pos:
mi = (lo + pos) >> 1
if value <= _mins[mi]:
pos = mi
else:
lo = mi
if pos and value <= _lists[pos - 1][-1]:
pos -= 1
_list = _lists[pos]
lo, idx = -1, len(_list)
while lo + 1 < idx:
mi = (lo + idx) >> 1
if value <= _list[mi]:
idx = mi
else:
lo = mi
return pos, idx
def _loc_right(self, value):
"""Return an index pair that corresponds to the last position of `value` in the sorted list."""
if not self._len:
return 0, 0
_lists = self._lists
_mins = self._mins
pos, hi = 0, len(_lists)
while pos + 1 < hi:
mi = (pos + hi) >> 1
if value < _mins[mi]:
hi = mi
else:
pos = mi
_list = _lists[pos]
lo, idx = -1, len(_list)
while lo + 1 < idx:
mi = (lo + idx) >> 1
if value < _list[mi]:
idx = mi
else:
lo = mi
return pos, idx
def add(self, value):
"""Add `value` to sorted list."""
_load = self._load
_lists = self._lists
_mins = self._mins
_list_lens = self._list_lens
self._len += 1
if _lists:
pos, idx = self._loc_right(value)
self._fen_update(pos, 1)
_list = _lists[pos]
_list.insert(idx, value)
_list_lens[pos] += 1
_mins[pos] = _list[0]
if _load + _load < len(_list):
_lists.insert(pos + 1, _list[_load:])
_list_lens.insert(pos + 1, len(_list) - _load)
_mins.insert(pos + 1, _list[_load])
_list_lens[pos] = _load
del _list[_load:]
self._rebuild = True
else:
_lists.append([value])
_mins.append(value)
_list_lens.append(1)
self._rebuild = True
def discard(self, value):
"""Remove `value` from sorted list if it is a member."""
_lists = self._lists
if _lists:
pos, idx = self._loc_right(value)
if idx and _lists[pos][idx - 1] == value:
self._delete(pos, idx - 1)
def remove(self, value):
"""Remove `value` from sorted list; `value` must be a member."""
_len = self._len
self.discard(value)
if _len == self._len:
raise ValueError('{0!r} not in list'.format(value))
def pop(self, index=-1):
"""Remove and return value at `index` in sorted list."""
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
value = self._lists[pos][idx]
self._delete(pos, idx)
return value
def bisect_left(self, value):
"""Return the first index to insert `value` in the sorted list."""
pos, idx = self._loc_left(value)
return self._fen_query(pos) + idx
def bisect_right(self, value):
"""Return the last index to insert `value` in the sorted list."""
pos, idx = self._loc_right(value)
return self._fen_query(pos) + idx
def count(self, value):
"""Return number of occurrences of `value` in the sorted list."""
return self.bisect_right(value) - self.bisect_left(value)
def __len__(self):
"""Return the size of the sorted list."""
return self._len
def __getitem__(self, index):
"""Lookup value at `index` in sorted list."""
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
return self._lists[pos][idx]
def __delitem__(self, index):
"""Remove value at `index` from sorted list."""
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
self._delete(pos, idx)
def __contains__(self, value):
"""Return true if `value` is an element of the sorted list."""
_lists = self._lists
if _lists:
pos, idx = self._loc_left(value)
return idx < len(_lists[pos]) and _lists[pos][idx] == value
return False
def __iter__(self):
"""Return an iterator over the sorted list."""
return (value for _list in self._lists for value in _list)
def __reversed__(self):
"""Return a reverse iterator over the sorted list."""
return (value for _list in reversed(self._lists) for value in reversed(_list))
def __repr__(self):
"""Return string representation of sorted list."""
return 'SortedList({0})'.format(list(self))
if __name__ == "__main__":
n, m, k = map(int, input().split())
a = list(map(int, input().split()))
a.sort()
timeline = SortedList(a)
ans = 0
l = 0
while l < len(timeline):
r = timeline.bisect_right(timeline[l] + m - 1)
cnt = r - l # [timeline[l],timeline[l]+m-1]内的点数
ok = 1
if cnt >= k:
if cnt == k:
ok = 0
delta = cnt - k + 1 # 要删除最右边cnt-k+1个点
ans += delta
dels = []
for i in range(r - delta, r):
dels.append(timeline[i])
for t in dels:
timeline.discard(t)
l += ok
print(ans)