分治法 + 归并排序
这个代码有点颠覆我的三观, 为什么呢?
1. 先根据x坐标排序, 然后根据y坐标排序
2. 计算最邻近距离的函数与归并排序有机结合了.
我第一眼看代码的时候, 非常好奇为什么可以直接归并了, 好像并没有说左区间的y, 右区间的y已经排过序了啊.
后来发现最邻近点函数是递归函数, 在之前的左区间递归, 右区间递归中, 排好了序.
时间复杂度
平均时间复杂度为O(N*log(N))
最差时间复杂度为O(N^2)
最差情况举例: 核电站全部在左边, 特工全部在右边.
结合代码来说:
ans = min(dp(l, mid), dp(mid + 1, r))
这部分的ans是正无穷大, 并不会在后面中起到筛选点的作用, 即, points_star的长度为N.
下面有个二重循环计算.
for i in range(len(points_star)):
for j in range(i + 1, len(points_star)):
if points_star[j][1] - points_star[i][1] >= ans: break
ans = min(ans, dist(points_star[i], points_star[j]))
因此, 最差情况下, 时间复杂度是O(N^2)
好消息是: 测试样例里面不存在这种最差情况.
运行时间为1056 ms
快于我的预期, 因为给C++的时限是5秒, 而这个是Python代码.
Python 代码
def dist(p1, p2):
# 相同类型的点, 不应该计算距离, 定义为无穷大
if p1[2] == p2[2]:
return float("inf")
return ((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) ** 0.5
def dp(l, r):
# 边界条件: 当区间个数小于等于1的时候, 并不存在两个点的距离
# 因此, 定义为无穷大, 目的是为了不刷新答案.
if l >= r:
return float("inf")
mid = (l + r) // 2
# 立即求出x坐标的中位值, 因为之后会对points按照y进行重新排序.
xmid = points[mid][0]
# 分治, 求出左区间众点之间的最小距离与右区间众点之间的最小距离
ans = min(dp(l, mid), dp(mid + 1, r))
# 基于y坐标, 通过归并排序算法进行排序
# 之前递归已经对左区间与右区间部分以及排好序了, 现在只要进行归并即可.
tmp = []
i, j = l, mid + 1
while i <= mid and j <= r:
if points[i][1] <= points[j][1]:
tmp.append(points[i])
i += 1
else:
tmp.append(points[j])
j += 1
if i <= mid:
tmp.extend(points[i : mid + 1])
else:
tmp.extend(points[j : r + 1])
# 更新[l, r]区间内的points, 即, 根据y坐标进行了重新排序
for i in range(l, r + 1):
points[i] = tmp[i - l]
# 现在考虑求出左区间与右区间的点各取一个点的最短距离.
# 很明显, 这些点必然在两个区间分界线附近.
# 找到x坐标介于xmid - ans与xmid + ans之间(开区间)的点
points_star = []
for p in tmp:
if xmid - ans < p[0] < xmid + ans:
points_star.append(p)
# 现在考虑y坐标, 对于每个points_star里面点(x, y), 只需要考虑y坐标区间(y - ans, y + ans)
# 为了避免重复计算, 只考虑y坐标区间(y, y + ans)
# 因为点已经根据y坐标排过序了, 因此, 序号更大的点, y坐标更大(或者相等)
# 为了简化编程, 没有让两个点分别为左区间和右区间(很明显, 如果在同一个区间内内的话, 并不会更新ans,
# 属于多余运算, 但是, 这点多余运算, 在整个程序中微不足道, 因此, 就不必优化了)
for i in range(len(points_star)):
for j in range(i + 1, len(points_star)):
if points_star[j][1] - points_star[i][1] >= ans: break
ans = min(ans, dist(points_star[i], points_star[j]))
return ans
T = int(input())
for _ in range(T):
N = int(input())
points = []
for i in range(N):
x, y = map(int, input().split())
points.append((x, y, 1))
for i in range(N):
x, y = map(int, input().split())
points.append((x, y, 2))
points.sort()
print(f"{dp(0, 2 * N - 1):.3f}")
2023-05-10 测试,TLE.....