题目描述
给你两个整数 m
和 k
,以及数据流形式的若干整数。你需要实现一个数据结构,计算这个数据流的 MK 平均值。
MK 平均值 按照如下步骤计算:
- 如果数据流中的整数少于
m
个,MK 平均值 为-1
,否则将数据流中最后m
个元素拷贝到一个独立的容器中。 - 从这个容器中删除最小的
k
个数和最大的k
个数。 - 计算剩余元素的平均值,并 向下取整到最近的整数。
请你实现 MKAverage
类:
MKAverage(int m, int k)
用一个空的数据流和两个整数m
和k
初始化 MKAverage 对象。void addElement(int num)
往数据流中插入一个新的元素num
。int calculateMKAverage()
对当前的数据流计算并返回 MK 平均数,结果需 向下取整到最近的整数。
样例
输入:
[
"MKAverage",
"addElement", "addElement",
"calculateMKAverage",
"addElement",
"calculateMKAverage",
"addElement", "addElement", "addElement",
"calculateMKAverage"
]
[
[3, 1],
[3], [1],
[],
[10],
[],
[5], [5], [5],
[]
]
输出:
[null, null, null, -1, null, 3, null, null, null, 5]
解释:
MKAverage obj = new MKAverage(3, 1);
obj.addElement(3); // 当前元素为 [3]
obj.addElement(1); // 当前元素为 [3,1]
obj.calculateMKAverage(); // 返回 -1,因为 m = 3,但数据流中只有 2 个元素
obj.addElement(10); // 当前元素为 [3,1,10]
obj.calculateMKAverage(); // 最后 3 个元素为 [3,1,10]
// 删除最小以及最大的 1 个元素后,容器为 [3]
// [3] 的平均值等于 3/1 = 3,故返回 3
obj.addElement(5); // 当前元素为 [3,1,10,5]
obj.addElement(5); // 当前元素为 [3,1,10,5,5]
obj.addElement(5); // 当前元素为 [3,1,10,5,5,5]
obj.calculateMKAverage(); // 最后 3 个元素为 [5,5,5]
// 删除最小以及最大的 1 个元素后,容器为 [5]
// [5] 的平均值等于 5/1 = 5,故返回 5
限制
3 <= m <= 10^5
1 <= k*2 < m
1 <= num <= 10^5
addElement
与calculateMKAverage
总操作次数不超过10^5
次。
算法1
(二分,树状数组) 插入 $O(\log S)$;查询 $O(\log^2 S)$
- 维护两个树状数组 $f$ 和 $g$。其中 $f(i)$ 维护数字 $i$ 出现次数的前缀和,$g(i)$ 维护数字 $i$ 累加的前缀和。维护总和 $sum$。
- 插入一个数字 $x$ 时,更新树状数组 $f(x) + 1$,$g(x) + x$。
- 查询时,通过树状数组,二分查询第 $k$ 个与第 $m - k$ 个所在的数字,并根据 $g$ 求出最小的 $k$ 个数字与最大的 $k$ 个数字之和。
时间复杂度
- 插入时,更新树状数组,时间复杂度为 $O(\log S)$。其中 $S$ 为最大的数字。
- 查询时,二分第 $k$ 个与第 $m - k$ 个数字,二分过程每次查询需要 $O(\log S)$ 的时间,故单次查询的总时间为 $O(\log^2 S)$。
空间复杂度
- 需要 $O(m + S)$ 的空间存储数据流队列,以及两个树状数组。
C++ 代码
#define LL long long
#define NUM 100000
class MKAverage {
private:
queue<int> q;
LL f[NUM + 1], g[NUM + 1], sum;
int m, k;
void add(LL *f, int x, int y) {
for (; x <= NUM; x += x & -x)
f[x] += y;
}
LL query(LL *f, int x) {
LL tot = 0;
for (; x; x -= x & -x)
tot += f[x];
return tot;
}
public:
MKAverage(int m, int k) {
memset(f, 0, sizeof(f));
memset(g, 0, sizeof(g));
sum = 0;
this->m = m;
this->k = k;
}
void addElement(int num) {
q.push(num);
sum += num;
add(f, num, 1);
add(g, num, num);
if (q.size() > m) {
sum -= q.front();
add(f, q.front(), -1);
add(g, q.front(), -q.front());
q.pop();
}
}
LL left() {
int l = 1, r = NUM;
while (l < r) {
int mid = (l + r) >> 1;
if (query(f, mid) < k) l = mid + 1;
else r = mid;
}
return query(g, l) - (query(f, l) - k) * l;
}
LL right() {
int l = 1, r = NUM;
while (l < r) {
int mid = (l + r) >> 1;
if (query(f, mid) < m - k) l = mid + 1;
else r = mid;
}
return sum - (query(g, l) - (query(f, l) - (m - k)) * l);
}
int calculateMKAverage() {
if (q.size() < m) return -1;
return (sum - left() - right()) / (m - k - k);
}
};
/**
* Your MKAverage object will be instantiated and called as such:
* MKAverage* obj = new MKAverage(m, k);
* obj->addElement(num);
* int param_2 = obj->calculateMKAverage();
*/
算法2
(多重集/平衡树) 插入 $O(\log m)$;查询 $O(1)$
- 维护三个多重集:
L
记录最小的k
个数字,M
记录中间的m - k - k
个数字,R
记录最大的k
个数字。维护M
中数字的和sum
。 - 插入一个数字时,首先插入
L
。如果L
的大小超过了k
,则转移L
中最大的数字到M
。如果M
的大小超过了m - k - k
,则转移M
中最大的数字到R
。 - 这样
R
的大小为k
或者k + 1
,所以删除时,如果删除了L
中的数字,则需要从R
中取出一个最小的数字放到M
中,然后再从M
中取出一个最小的数字放到L
中。删除了M
或者R
中的元素同理。
时间复杂度
- 每次插入需要调整若干个多重集,故插入的时间复杂度为 $O(\log n)$。
- 查询时直接使用
sum
计算答案,故查询的时间复杂度为常数。
空间复杂度
- 需要 $O(m)$ 的额外空间存储数据流队列和三个多重集。
C++ 代码
#define LL long long
class MKAverage {
private:
queue<int> q;
int m, k;
LL sum;
multiset<int> L, M, R;
public:
MKAverage(int m, int k) {
sum = 0;
this->m = m;
this->k = k;
}
void transL2M() {
int x = *L.rbegin();
L.erase(L.find(x));
M.insert(x);
sum += x;
}
void transM2R() {
int x = *M.rbegin();
M.erase(M.find(x));
R.insert(x);
sum -= x;
}
void transR2M() {
int x = *R.begin();
R.erase(R.find(x));
M.insert(x);
sum += x;
}
void transM2L() {
int x = *M.begin();
M.erase(M.find(x));
L.insert(x);
sum -= x;
}
void insert(int num) {
L.insert(num);
if (L.size() > k) transL2M();
if (M.size() > m - k - k) transM2R();
}
void erase(int num) {
if (L.find(num) != L.end()) {
L.erase(L.find(num));
transR2M();
transM2L();
} else if (M.find(num) != M.end()) {
M.erase(M.find(num));
sum -= num;
transR2M();
} else {
R.erase(R.find(num));
}
}
void addElement(int num) {
q.push(num);
insert(num);
if (q.size() > m) {
erase(q.front());
q.pop();
}
}
int calculateMKAverage() {
if (q.size() < m) return -1;
return sum / (m - k - k);
}
};
/**
* Your MKAverage object will be instantiated and called as such:
* MKAverage* obj = new MKAverage(m, k);
* obj->addElement(num);
* int param_2 = obj->calculateMKAverage();
*/