2024-04-03 高斯混合模型

k-means模型没有对边界附近的点的聚类分配的概率或者不确定性进行度量,显得不够通用。而且,簇模型的形状只能说圆形不够灵活(椭圆形)。k-means的两个缺点:

  • 类的形状缺少灵活性。
  • 缺少簇分配的概率。、

高斯混合模型(Gaussian mixture model,GMM)使用多维高斯分布的混合对输入数据进行建模。predict_prob方法给出任意点属于某个簇的概率。
可以使用全协方差拟合数据。

rng = np.random.RandomState(13)
X_stretched = np.dot(X, rng.randn(2, 2))

gmm = GaussianMixture(n_components=4, covariance_type='full', random_state=42)
plot_gmm(gmm, X_stretched)
e919f0f24fc648dba7bad5a556dfbe10.png

GMM用作密度估计

GMM本质上是一个密度估计算法,是描述数据分布的生成概率模型。
也就是说,GMM为我们提供了生成与输入数据分布类似的新随机数据的方法。作为一种非常方便的建模方法,GMM可以为数据估计出任意维度的随机分布。

使用赤池信息准则(Akaike information criterion,AIC)、贝叶斯信息准则(Bayesian information criterion, BIC)来找最优的n_components。

from sklearn.datasets import make_moons
Xmoon, ymoon = make_moons(200, noise=.05, random_state=0)
plt.scatter(Xmoon[:, 0], Xmoon[:, 1]);

n_components = np.arange(1, 21)
models = [GaussianMixture(n, covariance_type='full', random_state=0).fit(Xmoon)
          for n in n_components]
plt.figure()
plt.plot(n_components, [m.bic(Xmoon) for m in models], label='BIC')
plt.plot(n_components, [m.aic(Xmoon) for m in models], label='AIC')
plt.legend(loc='best')
plt.xlabel('n_components');

案例:使用GMM生成新的数据

使用标准手写数字库生成新的手写数字。

  1. 使用PCA投影保留99.9%的方差,将维度从8\times8=64维降到41维。
  2. 使用AIC估计GMM的成分数量。
  3. 拟合数据,确认收敛,逆变换。
from sklearn.datasets import load_digits
digits = load_digits()
print(digits.data.shape)
def plot_digits(data):
    fig, ax = plt.subplots(10, 10, figsize=(8, 8),
                           subplot_kw=dict(xticks=[], yticks=[]))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    for i, axi in enumerate(ax.flat):
        im = axi.imshow(data[i].reshape(8, 8), cmap='binary')
        im.set_clim(0, 16)
plot_digits(digits.data)

from sklearn.decomposition import PCA
pca = PCA(0.999,whiten=True)
data = pca.fit_transform(digits.data)
print(data.shape)

n_components = np.arange(50,210,10)
models = [GaussianMixture(n, covariance_type='full', random_state=0)
          for n in n_components]
aics = [model.fit(data).aic(data) for model in models]
plt.figure()
plt.plot(n_components, aics);

gmm = GaussianMixture(110, covariance_type='full', random_state=0)
gmm.fit(data)
print(gmm.converged_)

data_new,_ = gmm.sample(100)
print(data_new.shape)

digits_new = pca.inverse_transform(data_new)
plot_digits(digits_new)
297f11323c204301b55d5419b02cfaa9.png

参考:
[1]美 万托布拉斯 (VanderPlas, Jake).Python数据科学手册[M].人民邮电出版社,2018.

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容