# kmeans.py
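"""
手写kmeans: a hand-written k-means clustering implementation using NumPy,
with a small demo that plots the clustering progress.
"""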
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
import typing
class KMeans:
    """Minimal k-means clustering (Lloyd's algorithm) implemented with NumPy.

    Example::

        model = KMeans(3).fit(X)
        labels = model.predict(X)

    Attributes:
        k: Number of clusters.
        random_state: Optional seed for centre initialisation. ``None`` uses the
            global ``np.random`` state, so ``np.random.seed`` still applies.
    """

    def __init__(self, k: int, random_state: typing.Optional[int] = None):
        self.k = k
        self.random_state = random_state
        # Fitted centres, shape (k, n_features); None until `fit` has run.
        self._centers = None

    def fit(self, nda: np.ndarray, n_iters=10,
            callback: typing.Optional[typing.Callable] = None):
        """Run `n_iters` assign/update rounds on `nda` and store the centres.

        Args:
            nda: Samples, shape (n_samples, n_features).
            n_iters: Number of assign/update rounds to run.
            callback: Optional ``callback(nda, labels, i)`` invoked after each
                round with the labels computed in that round.

        Returns:
            self, so calls can be chained (``KMeans(3).fit(X).predict(X)``).

        Raises:
            ValueError: if `nda` is not 2-D, or `k` is not in 1..n_samples.
        """
        nda = np.asarray(nda, dtype=float)
        if nda.ndim != 2:
            raise ValueError(
                "nda must be 2-D (n_samples, n_features), got shape %s"
                % (nda.shape,))
        if not 1 <= self.k <= nda.shape[0]:
            raise ValueError(
                "k must be between 1 and n_samples (%d), got %r"
                % (nda.shape[0], self.k))
        # Start from actual data points: centres drawn from [0, 1) can sit far
        # from the data and leave clusters empty.
        centers = self.init_centers(nda, self.k, self.random_state)
        for i in range(n_iters):
            labels = self.assign(centers, nda)
            centers = self.update(nda, labels, self.k, centers)
            if callback is not None:
                callback(nda, labels, i)
        # Persist the result so `predict` / `centers` work after fitting.
        self._centers = centers
        return self

    def predict(self, nda: np.ndarray) -> np.ndarray:
        """Return the index of the nearest fitted centre for every row of `nda`."""
        return self.assign(self.centers, np.asarray(nda, dtype=float))

    @property
    def centers(self) -> np.ndarray:
        """Fitted centres, shape (k, n_features).

        Raises:
            AttributeError: if `fit` has not been called yet.
        """
        # Explicit `is None`: truthiness of a multi-element ndarray raises.
        if self._centers is None:
            raise AttributeError("Call 'fit' before reference to centers.")
        return self._centers

    @staticmethod
    def random_centers(k, n_features):
        """Return `k` points drawn uniformly from [0, 1) ** n_features.

        Kept for backward compatibility. `fit` uses `init_centers` instead,
        because unit-cube centres ignore where the data actually lies.
        """
        return np.random.random((k, n_features))

    @staticmethod
    def init_centers(nda, k, random_state=None):
        """Pick `k` distinct rows of `nda` at random as initial centres.

        Args:
            nda: Samples, shape (n_samples, n_features), n_samples >= k.
            k: Number of centres to pick.
            random_state: Optional int seed; ``None`` uses the global
                ``np.random`` state.

        Returns:
            Float array of shape (k, n_features).
        """
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        idx = rng.choice(nda.shape[0], size=k, replace=False)
        return np.asarray(nda, dtype=float)[idx]

    @staticmethod
    def assign(centers, nda):
        """Return, for each row of `nda`, the index of its nearest centre.

        Ties go to the lowest centre index. Builds an (n_samples, k) matrix of
        squared distances and takes the row-wise argmin instead of looping.

        Returns:
            Integer array of shape (n_samples,).
        """
        centers = np.asarray(centers, dtype=float)
        nda = np.asarray(nda, dtype=float)
        diff = nda[:, np.newaxis, :] - centers[np.newaxis, :, :]
        sq_dist = (diff ** 2).sum(axis=2)
        return np.argmin(sq_dist, axis=1)

    @staticmethod
    def update(nda, labels, k, prev_centers=None):
        """Recompute each centre as the mean of the samples assigned to it.

        Args:
            nda: Samples, shape (n_samples, n_features).
            labels: Cluster index for each sample.
            k: Number of clusters.
            prev_centers: Optional (k, n_features) centres from the previous
                round. A cluster with no members keeps its previous centre;
                without `prev_centers` such a centre is NaN.

        Returns:
            Float array of shape (k, n_features).
        """
        nda = np.asarray(nda, dtype=float)
        labels = np.asarray(labels)
        centers = np.empty((k, nda.shape[1]))
        for i in range(k):
            if np.any(labels == i):
                centers[i] = KMeans.cal_center(nda, labels, i)
            elif prev_centers is not None:
                centers[i] = prev_centers[i]
            else:
                centers[i] = np.nan
        return centers

    @staticmethod
    def distance(arr1, arr2):
        """Return the squared Euclidean distance between `arr1` and `arr2`."""
        return np.sum((arr1 - arr2) ** 2)

    @staticmethod
    def cal_center(nda, labels, i):
        """Return the per-feature mean of the rows of `nda` labelled `i`."""
        # axis=0 averages each feature separately; without it the mean runs
        # over every element and the centre collapses to a single scalar.
        return np.mean(nda[np.asarray(labels) == i], axis=0)

    @staticmethod
    def nearest_center(centers, nda):
        """Return the index of the centre closest to the point `nda`.

        Ties go to the lowest index. Returns -1 if `centers` is empty.
        """
        best_idx = -1
        best_dist = np.inf  # np.PINF was removed in NumPy 2.0
        for idx, center in enumerate(centers):
            dist = KMeans.distance(center, nda)
            if dist < best_dist:
                best_dist = dist
                best_idx = idx
        return best_idx
def my_plot(nda, labels, i, every=10):
    """Save a scatter plot of the first two columns of `nda`, coloured by `labels`.

    Intended as a `KMeans.fit` callback. Only iterations where
    ``i % every == 0`` are plotted; each is saved to ``"<i>.png"`` in the
    current working directory.

    Args:
        nda: Samples, indexed as ``nda[:, 0]`` / ``nda[:, 1]``.
        labels: Per-sample values used as scatter colours.
        i: Current iteration number.
        every: Plot only every `every`-th iteration (default 10).
    """
    if i % every != 0:
        return
    # Use a fresh figure per snapshot. Drawing on the implicit current figure
    # made every saved image also contain the points from all earlier snapshots.
    fig, ax = plt.subplots()
    try:
        ax.scatter(nda[:, 0], nda[:, 1], c=labels)
        ax.set_title("i = %s" % i)
        fig.savefig("%s.png" % i)
    finally:
        # Release the figure so repeated callbacks don't pile up open figures.
        plt.close(fig)
def main():
    """Demo: cluster four synthetic 2-D blobs, plotting progress via `my_plot`."""
    blob_centers = [[-1, -1], [0, 0], [1, 1], [2, 2]]
    blob_stds = [0.4, 0.2, 0.2, 0.2]
    samples, _ = make_blobs(
        n_samples=1000,
        n_features=2,
        centers=blob_centers,
        cluster_std=blob_stds,
        random_state=9,
    )
    model = KMeans(4)
    model.fit(samples, n_iters=40, callback=my_plot)
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
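# 运行结果 (run results): with n_iters=40, my_plot saves a scatter plot every
# 10 iterations (0.png, 10.png, 20.png, 30.png). The images are not included here.
# [image]
# [image]
# [image]
# [image]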