题目描述
难度分:1500
输入n(1≤n≤500),x(1≤x≤105)和长为n的数组a(1≤a[i]≤105)。
向a中添加尽量少的数,使得a的中位数恰好等于x。输出添加的元素个数。
注:如果n是偶数,中位数取正中间左边那个。例如a=[1,3,5,7]的中位数是3。
输入样例1
3 10
10 20 30
输出样例1
1
输入样例2
3 4
1 2 3
输出样例2
4
算法
二分答案
写复杂了,做的时候忘记考虑1≤n≤500(也不知道怎么有效利用这个条件)。想着能不能通过分情况讨论做到O(1),但是一直WA
,所以最后写了个很复杂的二分。
先对a排序,得到数组中小于x的数字个数lt_cnt,大于x的数字个数gt_cnt,以及等于x的数字个数cnt。如果a中根本就没有x,先令数组a中大于等于x的数字个数为geq_cnt=gt_cnt+cnt,分为以下三种情况:
- 如果lt_cnt<geq_cnt,只需要补一个x,然后补上geq_cnt−lt_cnt−1个小于x的数即可。
- 如果lt_cnt>geq_cnt,那就需要补一个x,然后补大于等于x的元素使数组长度为2lt_cnt+1,一共补了2lt_cnt+1−n个数。
- 如果lt_cnt=geq_cnt,只要补一个x就行。
否则a中有x,中位数的位置为pivot=⌈n2⌉,其中⌈.⌉表示对.向上取整。分为以下三种情况:
- pivot∈[j,k),其中j是第一个大于等于x的位置,k是第一个严格大于x的位置。此时x就是原数组a的中位数,不用添加任何其他元素。
- pivot<j,大于等于x的数少了,有两种方案:添加大于x的数、添加等于x的数。通过二分来确定最少添加多少个,如果添加mid个,目的就是让lt_cnt<⌈n+mid2⌉且lt_cnt+cnt≥⌈n+mid2⌉。意思就是小于x的数不能触碰到中位数的位置,同时等于x的数要覆盖住中位数的位置。
- pivot≥k,小于等于x的数少了,有两种方案:添加小于x的数、添加等于x的数。还是对这两种情况分别做二分答案,看哪种添加的数更少。
复杂度分析
时间复杂度
先对a数组进行了排序,在最差情况下还要对a数组进行两次二分答案,时间复杂度为O(log2n)。因此,整个算法的时间瓶颈在于排序,时间复杂度为O(nlog2n)。
空间复杂度
只使用了有限几个变量,额外空间复杂度为O(1)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 501;
int n, x, a[N];
int main() {
scanf("%d%d", &n, &x);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
sort(a + 1, a + n + 1);
int j = lower_bound(a + 1, a + n + 1, x) - a;
int k = upper_bound(a + 1, a + n + 1, x) - a;
if(j <= n) {
int cnt = k - j;
int lt_cnt = j - 1; // 小于x的个数
int gt_cnt = n - k + 1; // 大于x的个数
int geq_cnt = cnt + gt_cnt;
if(a[j] == x) {
// 数组a中有x这个数
int pivot = n + 1 >> 1;
if(j <= pivot && pivot < k) {
puts("0");
}else {
if(pivot < j) {
int l = 1, r = n;
// 加大于x的数
while(l < r) {
int mid = l + r >> 1;
if(lt_cnt + cnt >= (n + mid + 1) / 2) {
if(lt_cnt < (n + mid + 1) / 2) {
r = mid;
}else {
l = mid + 1;
}
}else {
l = mid + 1;
}
}
int t = r;
// 加等于x的数
l = 1, r = n;
while(l < r) {
int mid = l + r >> 1;
if(lt_cnt + cnt + mid >= (n + mid + 1) / 2) {
if(lt_cnt < (n + mid + 1) / 2) {
r = mid;
}else {
l = mid + 1;
}
}else {
l = mid + 1;
}
}
t = min(r, t);
printf("%d\n", t);
}else {
// 加等于x的数
int l = 1, r = n;
while(l < r) {
int mid = l + r >> 1;
if(lt_cnt + mid + cnt >= (n + mid + 1)/2) {
if(lt_cnt < (n + mid + 1) / 2) {
r = mid;
}else {
l = mid + 1;
}
}else {
l = mid + 1;
}
}
// 加小于x的数
int t = r;
l = 1, r = n;
while(l < r) {
int mid = l + r >> 1;
if(lt_cnt + mid + cnt >= (n + mid + 1)/2) {
if(lt_cnt + mid < (n + mid + 1) / 2) {
r = mid;
}else {
l = mid + 1;
}
}else {
l = mid + 1;
}
}
t = min(t, r);
printf("%d\n", t);
}
}
}else {
// 数组中没有x这个数
if(geq_cnt > lt_cnt) {
printf("%d\n", geq_cnt - lt_cnt);
}else if(geq_cnt < lt_cnt) {
printf("%d\n", lt_cnt*2 + 1 - n);
}else {
puts("1");
}
}
}else {
// 数组a中全是小于x的数
printf("%d\n", n + 1);
}
return 0;
}