高斯混合模型协方差

Beginner

This tutorial is from open-source community. Access the source code

简介

本教程演示了高斯混合模型(GMM)中不同协方差类型的使用。GMM 常用于聚类,我们可以将得到的聚类与数据集中的实际类别进行比较。为了使这种比较有效,我们用训练集类别的均值来初始化高斯分布的均值。我们在鸢尾花数据集上使用多种 GMM 协方差类型,在训练数据和留出的测试数据上绘制预测标签。我们按照性能递增的顺序比较具有球形、对角、全协方差和绑定协方差矩阵的 GMM。

虽然一般认为全协方差通常表现最佳,但它在小数据集上容易过拟合,对留出的测试数据泛化能力不佳。

在图中,训练数据用点表示,而测试数据用叉表示。鸢尾花数据集是四维的。这里只显示了前两个维度,因此有些点在其他维度上是分开的。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。

导入库

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold

加载鸢尾花数据集

iris = datasets.load_iris()

准备训练数据和测试数据

skf = StratifiedKFold(n_splits=4)
train_index, test_index = next(iter(skf.split(iris.data, iris.target)))

X_train = iris.data[train_index]
y_train = iris.target[train_index]
X_test = iris.data[test_index]
y_test = iris.target[test_index]

为不同协方差类型设置高斯混合模型(GMM)估计器

colors = ["navy", "turquoise", "darkorange"]
n_classes = len(np.unique(y_train))

estimators = {
    cov_type: GaussianMixture(
        n_components=n_classes, covariance_type=cov_type, max_iter=20, random_state=0
    )
    for cov_type in ["spherical", "diag", "tied", "full"]
}

n_estimators = len(estimators)

定义一个函数来绘制高斯混合模型(GMM)的椭圆

def make_ellipses(gmm, ax):
    for n, color in enumerate(colors):
        if gmm.covariance_type == "full":
            covariances = gmm.covariances_[n][:2, :2]
        elif gmm.covariance_type == "tied":
            covariances = gmm.covariances_[:2, :2]
        elif gmm.covariance_type == "diag":
            covariances = np.diag(gmm.covariances_[n][:2])
        elif gmm.covariance_type == "spherical":
            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
        v, w = np.linalg.eigh(covariances)
        u = w[0] / np.linalg.norm(w[0])
        angle = np.arctan2(u[1], u[0])
        angle = 180 * angle / np.pi
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
        ell = mpl.patches.Ellipse(
            gmm.means_[n, :2], v[0], v[1], angle=180 + angle, color=color
        )
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(0.5)
        ax.add_artist(ell)
        ax.set_aspect("equal", "datalim")

绘制不同协方差类型的高斯混合模型(GMM)

plt.figure(figsize=(3 * n_estimators // 2, 6))
plt.subplots_adjust(
    bottom=0.01, top=0.95, hspace=0.15, wspace=0.05, left=0.01, right=0.99
)

for index, (name, estimator) in enumerate(estimators.items()):
    estimator.means_init = np.array(
        [X_train[y_train == i].mean(axis=0) for i in range(n_classes)]
    )

    estimator.fit(X_train)

    h = plt.subplot(2, n_estimators // 2, index + 1)
    make_ellipses(estimator, h)

    for n, color in enumerate(colors):
        data = iris.data[iris.target == n]
        plt.scatter(
            data[:, 0], data[:, 1], s=0.8, color=color, label=iris.target_names[n]
        )

    for n, color in enumerate(colors):
        data = X_test[y_test == n]
        plt.scatter(data[:, 0], data[:, 1], marker="x", color=color)

    y_train_pred = estimator.predict(X_train)
    train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
    plt.text(0.05, 0.9, "Train accuracy: %.1f" % train_accuracy, transform=h.transAxes)

    y_test_pred = estimator.predict(X_test)
    test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
    plt.text(0.05, 0.8, "Test accuracy: %.1f" % test_accuracy, transform=h.transAxes)

    plt.xticks(())
    plt.yticks(())
    plt.title(name)

plt.legend(scatterpoints=1, loc="lower right", prop=dict(size=12))
plt.show()

总结

本教程展示了如何在 Python 中对高斯混合模型(GMM)使用不同的协方差类型。我们以鸢尾花数据集为例,比较了具有球形、对角、完全和绑定协方差矩阵的高斯混合模型,其性能按升序排列。我们在训练数据和留出的测试数据上绘制了预测标签,并表明完全协方差在小数据集上容易过拟合,并且对留出的测试数据的泛化能力不佳。