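A recent project of mine needs the mean-shift [0] algorithm. The first step is to pick a machine learning toolkit that includes mean-shift, preferably an open-source one, so that we can modify it later if we need to.
Here we chose Scikit-learn [1.5], an open-source machine learning toolkit implemented in Python. Its GitHub repository is at [2].
We start from the official demo [3].
First, import what we need from the corresponding packages: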
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
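from sklearn.datasets import make_blobs  # the older sklearn.datasets.samples_generator path is gone in recent releases
As the names suggest, sklearn.cluster contains clustering algorithms, while make_blobs from sklearn.datasets generates sample data.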
Generate the sample data:
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)
Then run mean-shift on the sample data:
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)
print("number of estimated clusters : %d" % n_clusters_)
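estimate_bandwidth() estimates the size of the mean-shift window. With the arguments above it works on 500 samples drawn at random from X, and quantile=0.2 selects the scale: the documentation describes quantile=0.5 as using the median of all pairwise distances, so 0.2 corresponds roughly to the 0.2 quantile of those distances. Because the estimate depends on distances between the sampled points, its cost grows at least quadratically with n_samples, so keep n_samples small on large data sets.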
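To make the meaning of quantile more concrete, here is a small sketch that compares the library estimate with two hand-computed values. The nearest-neighbor version is my assumption about how the estimate is approximated internally (average distance to the k-th nearest neighbor, with k = n_points * quantile), and the variable names bw_library, bw_knn and bw_pairwise are made up for this example.

```python
import numpy as np
from sklearn.cluster import estimate_bandwidth
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=1000, centers=centers, cluster_std=0.6,
                  random_state=0)
quantile = 0.2

# Library estimate on all of X (n_samples is left at its default, so no subsampling).
bw_library = estimate_bandwidth(X, quantile=quantile)

# Assumed approximation: for every point take the distance to its k-th
# nearest neighbor (the point itself counts as the first), then average.
k = max(1, int(len(X) * quantile))
nn = NearestNeighbors(n_neighbors=k).fit(X)
distances, _ = nn.kneighbors(X)
bw_knn = distances[:, -1].mean()

# Literal reading of the description: the 0.2 quantile of all pairwise distances.
d = pairwise_distances(X)
upper = np.triu_indices_from(d, k=1)  # each pair once, no self-distances
bw_pairwise = np.quantile(d[upper], quantile)

print("library:", bw_library)
print("k-NN average:", bw_knn)
print("pairwise quantile:", bw_pairwise)
```

The pairwise version builds an n-by-n distance matrix, which is exactly why the cost explodes when n_samples is large.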
np.unique(labels) returns the sorted distinct values in labels. Taking len() of that array gives the number of clusters found.
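A tiny example with a hand-written label array (the values are made up) shows the difference between the unique values and their count. Note that if the labels contain -1 (the orphan label), len() counts it as well.

```python
import numpy as np

# Hypothetical labels; -1 marks an orphan point.
labels = np.array([0, 2, 1, 0, 2, 2, -1])

labels_unique = np.unique(labels)   # sorted distinct values
n_labels = len(labels_unique)       # how many distinct values there are

# return_counts=True additionally reports how often each value occurs.
values, counts = np.unique(labels, return_counts=True)
sizes = dict(zip(values.tolist(), counts.tolist()))

print(labels_unique)  # [-1  0  1  2]
print(n_labels)       # 4
print(sizes)          # {-1: 1, 0: 2, 1: 1, 2: 3}
```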
The constructor MeanShift() is the key part. Its signature in the version used here is:
MeanShift(bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, n_jobs=1)
Its parameters are:
bandwidth: float. Bandwidth used in the RBF (Radial Basis Function) kernel. If not given, the bandwidth is estimated using sklearn.cluster.estimate_bandwidth.
seeds: array, shape=[n_samples, n_features]. Seeds used to initialize kernels. If not set, the seeds are calculated by clustering.get_bin_seeds with bandwidth as the grid size and default values for other parameters.
bin_seeding: boolean. If true, initial kernel locations are not locations of all points, but rather the location of the discretized version of points, where points are binned onto a grid whose coarseness corresponds to the bandwidth. Setting this option to True will speed up the algorithm because fewer seeds will be initialized. Ignored if seeds argument is not None.
min_bin_freq: int, optional. To speed up the algorithm, accept only those bins with at least min_bin_freq points as seeds. Default is 1.
cluster_all: boolean. If true, then all points are clustered, even those orphans that are not within any kernel. Orphans are assigned to the nearest kernel. If false, then orphans are given cluster label -1.
n_jobs: int. The number of jobs to use for the computation. This works by running the mean-shift iterations for the individual seeds in parallel. If -1 all CPUs are used. If 1 is given, no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used.
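The effect of cluster_all is easiest to see with an outlier. The sketch below is only an illustration: the data, the fixed bandwidth of 0.8 and min_bin_freq=5 are values I picked so that the lone far-away point does not become a seed of its own.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.4,
                  random_state=0)
# Append one point far away from every blob; it is the last row of X.
X = np.vstack([X, [[10.0, 10.0]]])

params = dict(bandwidth=0.8, bin_seeding=True, min_bin_freq=5)

# cluster_all=True: every point gets the label of its nearest cluster center.
ms_all = MeanShift(cluster_all=True, **params).fit(X)

# cluster_all=False: points that lie outside every kernel get label -1.
ms_orphans = MeanShift(cluster_all=False, **params).fit(X)

print("outlier label, cluster_all=True: ", ms_all.labels_[-1])
print("outlier label, cluster_all=False:", ms_orphans.labels_[-1])
print("points labeled -1:", int(np.sum(ms_orphans.labels_ == -1)))
```

With the default min_bin_freq=1 the outlier's grid cell would itself qualify as a seed, so the point could end up as a one-member cluster instead of an orphan.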
Other commonly used attributes and methods of the MeanShift class:
cluster_centers_: array, [n_clusters, n_features]. Coordinates of cluster centers.
labels_: array, [n_samples]. Labels of each point.
fit(X): Perform clustering.
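Beyond the members listed above, a fitted MeanShift object also offers predict(), which assigns new points to the nearest existing cluster center. A minimal sketch, assuming well-separated blobs and a hand-picked bandwidth of 0.8 (new_points is made up for the example):

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.3,
                  random_state=0)

ms = MeanShift(bandwidth=0.8, bin_seeding=True).fit(X)

# Attributes available after fit().
print(ms.cluster_centers_.shape)  # (n_clusters, n_features)
print(ms.labels_.shape)           # (n_samples,)

# Assign unseen points to the closest learned cluster center.
new_points = np.array([[1.1, 0.9], [-1.2, -0.8]])
pred = ms.predict(new_points)
for point, label in zip(new_points, pred):
    print(point, "-> cluster", label, "centered at", ms.cluster_centers_[label])
```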
Finally, plot the clustering result:
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle
plt.figure(1)
plt.clf()
colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
References
[0] Mean shift: A robust approach toward feature space analysis. D. Comaniciu and P. Meer, IEEE Transactions on Pattern Analysis and Machine Intelligence (2002)
[1.5] Scikit-learn: Machine Learning in Python. Pedregosa et al., JMLR 12, pp. 2825-2830, 2011
[1] tutorial: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html
[2] source: https://github.com/scikit-learn/scikit-learn
[3] demo: http://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py