import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
# Load the watermelon samples from CSV, echo the table to stdout, and keep
# the raw cell values as a NumPy array for the clustering code below.
dataset = pd.read_csv('watermelon.csv', sep=",")
print(dataset)
data = dataset.to_numpy()
def distance(x1, x2):
    """Return the *squared* Euclidean distance between ``x1`` and ``x2``.

    No square root is taken: squared distance preserves the ordering needed
    for nearest-centroid comparisons and is cheaper to compute.

    Accepts NumPy arrays, Python lists, or scalars of matching (or
    broadcastable) shape. Every element of the squared difference is summed,
    so multi-dimensional inputs yield a single scalar. The builtin ``sum``
    previously used here only reduced the first axis and returned a vector.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    return np.sum(diff ** 2)
def Kmeans(D, K, maxIter):
    """Cluster the rows of ``D`` into ``K`` groups with Lloyd's k-means.

    Initial centroids are ``K`` distinct rows of ``D`` chosen with the
    ``random`` module, so ``random.seed`` makes runs reproducible. Each
    iteration does two steps:
    1. Assign every sample to its nearest centroid by squared Euclidean
       distance; ties go to the lowest centroid index.
    2. Move each centroid to the mean of its assigned samples.
    Iteration stops early once the centroids no longer change.

    Args:
        D: 2-D array-like of shape (m, n), one sample per row.
        K: number of clusters; must be >= 1.
        maxIter: upper bound on the number of assign/update iterations.

    Returns:
        A tuple ``(U, C, iters)``:
        - ``U`` is the float array of centroids.
        - ``C`` is a length-m float array of cluster indices. The float
          dtype is kept for backward compatibility; callers index with
          ``int(C[i])``.
        - ``iters`` is the number of iterations actually run (0 when
          ``maxIter <= 0``).
        When ``K >= m``, every sample is its own centroid (``U`` has m rows)
        and ``iters`` is 0.

    Raises:
        ValueError: if ``K < 1``.
    """
    # Work in float so centroid means are never truncated to an int dtype.
    D = np.asarray(D, dtype=float)
    m, n = D.shape
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    if K >= m:
        # Every sample is its own cluster. Return the same (U, C, iters)
        # triple as the normal path so callers can always unpack it.
        return D.copy(), np.arange(m, dtype=float), 0

    # K distinct random rows as starting centroids. Fancy indexing returns a
    # copy, so updating U never mutates D.
    U = D[random.sample(range(m), K)]
    C = np.zeros(m)
    iters = 0
    for iters in range(1, maxIter + 1):
        # Assignment step: an (m, K) matrix of squared distances. argmin keeps
        # the first minimum, the same tie-break as a strict "<" scan.
        dists = ((D[:, np.newaxis, :] - U[np.newaxis, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        C = labels.astype(float)

        # Update step: each centroid moves to the mean of its members. An
        # empty cluster keeps its previous centroid instead of dividing by
        # zero, which would yield NaN and prevent convergence.
        newU = U.copy()
        for k in range(K):
            members = D[labels == k]
            if len(members):
                newU[k] = members.sum(axis=0) / len(members)

        if np.array_equal(newU, U):
            return U, C, iters
        U = newU
    return U, C, iters
# Cluster the watermelon samples into 3 groups, allowing at most 20 iterations.
# The count is bound to ``n_iter`` so the builtin ``iter`` is not shadowed.
U, C, n_iter = Kmeans(data, 3, 20)

# Plot the samples (green) and the final centroids (red). Then draw a thin
# dashed line from every sample to the centroid it was assigned to.
plt.figure(1)
plt.title('watermelon')
plt.xlabel('density')
plt.ylabel('ratio')
plt.scatter(data[:, 0], data[:, 1], marker='o', color='g', s=50)
plt.scatter(U[:, 0], U[:, 1], marker='o', color='r', s=100)
assigned = U[C.astype(int)]  # centroid row for each sample, same order as data
for (x, y), (cx, cy) in zip(data[:, :2], assigned[:, :2]):
    plt.plot([x, cx], [y, cy], "c--", linewidth=0.3)
plt.show()
时间复杂度分析:
1.初始化质心(期望 O(K * n)) – 随机选择 K 个互不相同的样本作为初始质心;当 K 远小于 m 时,期望只需约 K 次随机抽样,再复制 K 个 n 维样本,耗时 O(K * n)。与迭代部分相比,这一项可以忽略。
2.簇分配(每次迭代 O(m * K * n)) – 外层循环最多执行 maxIter 次;每次迭代中,对每个样本(一共 m 个)计算其到 K 个质心的距离,每次距离计算需要 O(n) 时间。
3.更新质心(每次迭代 O(m * n + K * n)) - 遍历所有样本(m 个),将样本累加到对应类别的质心,时间复杂度为 O(m * n);再对每个质心(K 个)求均值,时间复杂度为 O(K * n)。
综上:将上述部分的时间复杂度相加,得到整个 K-Means 算法的时间复杂度为:
O(K * n + maxIter * m * K * n + maxIter * (m * n + K * n)) = O(maxIter * m * K * n),即簇分配步骤占主导。
附录:Watermelon 数据集(摘选),保存为 watermelon.csv(文件名须与代码中 pd.read_csv 读取的名称完全一致;在区分大小写的文件系统上,Watermelon.csv 会找不到文件)
density,ratio
0.697,0.460
0.774,0.376
0.634,0.264
0.608,0.318
0.556,0.215
0.403,0.237
0.481,0.149
0.437,0.211
0.666,0.091
0.243,0.267
0.245,0.057
0.343,0.099
0.639,0.161
0.657,0.198
0.360,0.370
0.593,0.042
0.719,0.103
0.359,0.188
0.339,0.241
0.282,0.257
0.748,0.232
0.714,0.346
0.483,0.312
0.478,0.437
0.525,0.369
0.751,0.489
0.532,0.472
0.473,0.376
0.725,0.445
0.446,0.459