需求
原需求
因为本题需求(函数getLeastNumbers_Solution):
- 输出数组内元素请按从小到大顺序排序
- 以返回值的方式返回结果
所以:
- 堆算法: 时间复杂度:O(n*log(k)), 空间复杂度:O(k)
- 快速选择算法: 时间复杂度:平均O(n+k*log(k))、最差O(n^2),空间复杂度O(n)
新需求
如果本题修改需求(函数topK):
- 直接修改实参,把最小的k个数移动到数组的前k个位置
- 前k个数不要求排序
那么:
- 堆算法: 时间复杂度:O(n*log(k)), 空间复杂度:O(k)
- 快速选择算法: 时间复杂度:平均O(n)、最差O(n^2),空间复杂度O(1)
个人觉得新需求更符合题目“最小的k个数”,且使用函数topK很容易实现函数getLeastNumbers_Solution。
C++、Java、Python3代码全部通过。
算法
分析函数topK:
- 与快速排序算法类似,每次选择一个pivot,把<=pivot的数放在左边,把>=pivot的数放在右边
- 平均每次减少一半的数据、平均选择log(n)次,总处理量约为 n + n/2 + n/4 + … ≤ 2n,所以时间复杂度平均O(n)
- 如果每次选择最大值或最小值,那么每次减少1个数、选择n次,所以时间复杂度最差O(n^2)
- 选择left、mid、right三个位置上元素的中位数作为pivot,所以最差情况不容易发生
C++代码
class Solution {
public:
    // Returns the k smallest values of `input` in ascending order.
    // `input` is not modified: a copy is partitioned and then sorted.
    // Returns an empty vector when `input` is empty or k <= 0, and all of
    // `input` (sorted) when k >= input.size().
    static std::vector<int> getLeastNumbers_Solution(const std::vector<int> &input, int k) {
        if (input.empty() || k <= 0) {
            return {};
        }
        std::vector<int> nums = input;
        if (static_cast<int>(nums.size()) > k) {
            // Gather the k smallest values at the front, then drop the rest.
            topK(nums, k);
            nums.resize(static_cast<std::size_t>(k));
        }
        std::sort(nums.begin(), nums.end());
        return nums;
    }

    // Rearranges `nums` in place so that its first k positions hold the k
    // smallest values, in no particular order (quickselect with a
    // median-of-three pivot). Does nothing when k <= 0 or k >= nums.size().
    // Average O(n) time, worst case O(n^2); O(1) extra space.
    static void topK(std::vector<int> &nums, int k) {
        const int n = static_cast<int>(nums.size());
        if (k <= 0 || n <= k) {
            return;
        }
        int left = 0;
        int right = n - 1;
        // Invariant: left <= k - 1 < k <= right (so the range always has at
        // least two elements); values before `left` are <= every value in
        // nums[left..right], values after `right` are >= every such value.
        for (;;) {
            const int mid = left + ((right - left) >> 1);
            moveMedianToRight(nums[left], nums[mid], nums[right]);
            const int i = partitionAroundRight(nums, left, right);
            if (i + 1 < k) {
                left = i + 1;
            } else if (i > k) {
                right = i - 1;
            } else {
                // i == k - 1 or i == k: every value in nums[0..k-1] is <= every
                // value after it, so the first k positions are the k smallest.
                return;
            }
        }
    }

    // Moves the median of the three referenced values into `right` so that it
    // can serve as the pivot; the three values are only permuted among
    // themselves. When the range has two elements, `left` and `mid` refer to
    // the same element; the median of (a, a, b) is a, which still ends up in
    // `right`.
    static void moveMedianToRight(int &left, int &mid, int &right) {
        if (left < right) {
            if (right < mid) {
                return;  // left < right < mid: right is already the median
            }
            if (left < mid) {
                std::swap(mid, right);  // left < mid <= right
            } else {
                std::swap(left, right);  // mid <= left < right
            }
            return;
        }
        // right <= left
        if (left < mid) {
            std::swap(left, right);  // right <= left < mid
        } else if (right < mid) {
            std::swap(mid, right);  // right < mid <= left
        }
        // otherwise mid <= right <= left: right is already the median
    }

private:
    // Partitions nums[left..right] (requires left < right) around the pivot
    // value stored at nums[right] and returns the pivot's final index p, so
    // that nums[left..p-1] <= nums[p] <= nums[p+1..right].
    // Both scans stop on values equal to the pivot, so runs of duplicates are
    // split between the two sides instead of all landing on the left (which
    // made topK quadratic on inputs such as an all-equal array).
    static int partitionAroundRight(std::vector<int> &nums, int left, int right) {
        const int pivot = nums[right];
        int i = left - 1;
        int j = right;
        for (;;) {
            // nums[right] == pivot is never swapped inside this loop, so it
            // stops this scan and i never runs past right.
            while (nums[++i] < pivot) {
            }
            while (j > left && pivot < nums[--j]) {
            }
            if (i >= j) {
                break;
            }
            std::swap(nums[i], nums[j]);
        }
        std::swap(nums[i], nums[right]);
        return i;
    }
};
Java代码
class Solution {
public static List<Integer> getLeastNumbers_Solution(int[] input, int k) {
if (input.length == 0 || k <= 0) {
return new ArrayList<>();
}
if (input.length <= k) {
Arrays.sort(input);
List<Integer> list = new ArrayList<>(input.length);
for (int i = 0; i < input.length; ++i) {
list.add(input[i]);
}
return list;
}
topK(input, k);
Arrays.sort(input, 0, k);
List<Integer> list = new ArrayList<>(k);
for (int i = 0; i < k; ++i) {
list.add(input[i]);
}
return list;
}
public static void topK(int[] nums, int k) {
if (k <= 0 || nums.length <= k) {
return;
}
int left = 0;
int right = nums.length - 1;
for (; ; ) {
moveMedianToRight(nums, left, right);
final int pivot = nums[right];
int i = left;
int j = right;
while (i < j) {
while (i < j && nums[i] <= pivot) {
++i;
}
while (i < j && nums[j] >= pivot) {
--j;
}
swap(nums, i, j);
}
swap(nums, i, right);
if (i + 1 < k) {
left = i + 1;
} else if (i > k) {
right = i - 1;
} else {
return;
}
}
}
public static void moveMedianToRight(int[] nums, int left, int right) {
final int mid = left + ((right - left) >> 1);
if (nums[left] <= nums[right]) {
if (nums[right] <= nums[mid]) {
// left <= right <= mid
return;
}
// left <= right && mid < right
if (nums[left] <= nums[mid]) {
// left <= mid < right
swap(nums, mid, right);
return;
}
// mid < left <= right
swap(nums, left, right);
return;
} else {
// right < left
if (nums[left] < nums[mid]) {
// right < left < mid
swap(nums, left, right);
return;
}
// right < left && mid <= left
if (nums[right] < nums[mid]) {
// right < mid <= left
swap(nums, mid, right);
return;
}
// mid <= right < left
return;
}
}
public static void swap(int[] nums, int i, int j) {
final int tmp = nums[i];
nums[i] = nums[j];
nums[j] = tmp;
}
}
Python3代码
class Solution(object):
    def getLeastNumbers_Solution(self, input, k):
        """Return the k smallest values of ``input`` in ascending order.

        ``input`` itself is not modified; a copy is partitioned and sorted.
        Returns ``[]`` when ``input`` is empty (or None) or ``k <= 0``, and all
        of ``input`` sorted when ``k >= len(input)``.
        """
        if not input or k <= 0:
            return []
        nums = list(input)
        # topK is a no-op when k >= len(nums), so calling it unconditionally is safe.
        self.topK(nums, k)
        nums = nums[:k]
        nums.sort()
        return nums

    def topK(self, nums, k):
        """Rearrange ``nums`` in place so its first k slots hold the k smallest values.

        The first k values are in no particular order (quickselect with a
        median-of-three pivot). Does nothing when ``k <= 0`` or
        ``k >= len(nums)``. Average O(n) time, worst case O(n^2); O(1) extra space.
        """
        if k <= 0 or len(nums) <= k:
            return
        left = 0
        right = len(nums) - 1
        # Invariant: left <= k - 1 < k <= right (so the range always has at least
        # two elements); values before left are <= every value in nums[left..right],
        # values after right are >= every such value.
        while True:
            self.moveMedianToRight(nums, left, right)
            i = self._partitionAroundRight(nums, left, right)
            if i + 1 < k:
                left = i + 1
            elif i > k:
                right = i - 1
            else:
                # i == k - 1 or i == k: every value in nums[:k] is <= every value
                # after it, so the first k slots are the k smallest.
                return

    def moveMedianToRight(self, nums, left, right):
        """Move the median of nums[left], nums[mid], nums[right] into nums[right].

        ``mid`` is the midpoint index; only those three slots are permuted. For a
        two-element range ``mid == left``; the median of (a, a, b) is a, which
        still ends up in nums[right].
        """
        mid = left + ((right - left) >> 1)
        if nums[left] <= nums[right]:
            if nums[right] <= nums[mid]:
                # left <= right <= mid: right is already the median
                return
            # left <= right and mid < right
            if nums[left] < nums[mid]:
                # left < mid < right
                nums[mid], nums[right] = nums[right], nums[mid]
                return
            # mid <= left <= right
            nums[left], nums[right] = nums[right], nums[left]
            return
        # right < left
        if nums[mid] <= nums[right]:
            # mid <= right < left: right is already the median
            return
        # right < left and right < mid
        if nums[left] < nums[mid]:
            # right < left < mid
            nums[left], nums[right] = nums[right], nums[left]
            return
        # right < mid <= left
        nums[mid], nums[right] = nums[right], nums[mid]

    def _partitionAroundRight(self, nums, left, right):
        """Partition nums[left..right] (requires left < right) around nums[right].

        Returns the pivot's final index p, so that
        nums[left..p-1] <= nums[p] <= nums[p+1..right]. Both scans stop on values
        equal to the pivot, so runs of duplicates are split between the two sides
        instead of all landing on the left (which made topK quadratic on inputs
        such as an all-equal list).
        """
        pivot = nums[right]
        i = left - 1
        j = right
        while True:
            # nums[right] == pivot is never swapped inside this loop, so it stops
            # this scan and i never runs past right.
            i += 1
            while nums[i] < pivot:
                i += 1
            j -= 1
            while j > left and pivot < nums[j]:
                j -= 1
            if i >= j:
                break
            nums[i], nums[j] = nums[j], nums[i]
        nums[i], nums[right] = nums[right], nums[i]
        return i