Gaussian Mixture Model Covariances


This tutorial is adapted from an open-source community example.

Introduction

This tutorial demonstrates the use of different covariance types for Gaussian mixture models (GMMs). GMMs are often used for clustering, and we can compare the obtained clusters with the actual classes from the dataset. We initialize the means of the Gaussians with the means of the classes from the training set to make this comparison valid. We plot predicted labels on both training and held-out test data using a variety of GMM covariance types on the iris dataset. We compare GMMs with spherical, diagonal, full, and tied covariance matrices in increasing order of performance.

Although one would expect full covariance to perform best in general, it is prone to overfitting on small datasets and does not generalize well to held-out test data.

On the plots, train data is shown as dots, while test data is shown as crosses. The iris dataset is four-dimensional. Only the first two dimensions are shown here, and thus some points are separated in other dimensions.
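One way to see why full covariance is the most flexible but also the easiest to overfit is to count the free covariance parameters each type estimates. The short sketch below is an illustration added for this tutorial (not part of the original example); the counts for 3 components and 4 features follow directly from the definitions of the four covariance types.

# Free covariance parameters for a GMM with k components and d features
# (means and mixing weights are not counted here).
k, d = 3, 4  # iris: 3 classes used as components, 4 features

cov_params = {
    "spherical": k,                # one variance per component
    "diag": k * d,                 # one variance per feature, per component
    "tied": d * (d + 1) // 2,      # one full covariance shared by all components
    "full": k * d * (d + 1) // 2,  # one full covariance per component
}
print(cov_params)  # {'spherical': 3, 'diag': 12, 'tied': 10, 'full': 30}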

VM Tips

After the VM startup is done, click the top left corner to switch to the Notebook tab to access Jupyter Notebook for practice.

Sometimes, you may need to wait a few seconds for Jupyter Notebook to finish loading. The validation of operations cannot be automated because of limitations in Jupyter Notebook.

If you face issues during learning, feel free to ask Labby. Provide feedback after the session, and we will promptly resolve the problem for you.


Skills Graph

%%{init: {'theme':'neutral'}}%%
flowchart RL
    sklearn(("`Sklearn`")) -.-> sklearn/CoreModelsandAlgorithmsGroup(["`Core Models and Algorithms`"])
    sklearn(("`Sklearn`")) -.-> sklearn/ModelSelectionandEvaluationGroup(["`Model Selection and Evaluation`"])
    ml(("`Machine Learning`")) -.-> ml/FrameworkandSoftwareGroup(["`Framework and Software`"])
    sklearn/CoreModelsandAlgorithmsGroup -.-> sklearn/mixture("`Gaussian Mixture Models`")
    sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/model_selection("`Model Selection`")
    ml/FrameworkandSoftwareGroup -.-> ml/sklearn("`scikit-learn`")
    subgraph Lab Skills
        sklearn/mixture -.-> lab-49134{{"`Gaussian Mixture Model Covariances`"}}
        sklearn/model_selection -.-> lab-49134{{"`Gaussian Mixture Model Covariances`"}}
        ml/sklearn -.-> lab-49134{{"`Gaussian Mixture Model Covariances`"}}
    end

Import Libraries

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold

Load the Iris dataset

iris = datasets.load_iris()
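If you want a quick sanity check of what was loaded, the optional snippet below (not required for the rest of the tutorial) prints the basic shape and label information of the dataset.

print(iris.data.shape)     # (150, 4): 150 samples, 4 features
print(iris.feature_names)  # sepal/petal length and width
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']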

Prepare training and test data

# Split the dataset with stratified 4-fold cross-validation and keep only the
# first fold: about 75% of the samples for training and 25% for testing,
# with class proportions preserved in both sets.
skf = StratifiedKFold(n_splits=4)
train_index, test_index = next(iter(skf.split(iris.data, iris.target)))

X_train = iris.data[train_index]
y_train = iris.target[train_index]
X_test = iris.data[test_index]
y_test = iris.target[test_index]
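Because StratifiedKFold preserves the class proportions in every fold, taking the first of four folds gives roughly a 75%/25% split with balanced classes. The optional check below, added for this tutorial, confirms that.

# Each class should appear in roughly the same proportion in both sets.
print(X_train.shape, X_test.shape)            # about (112, 4) and (38, 4)
print(np.bincount(y_train), np.bincount(y_test))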

Set up GMM estimators for different covariance types

colors = ["navy", "turquoise", "darkorange"]
n_classes = len(np.unique(y_train))

# One GMM per covariance type; all share the same number of components,
# iteration budget, and random seed, so only the covariance structure differs.
estimators = {
    cov_type: GaussianMixture(
        n_components=n_classes, covariance_type=cov_type, max_iter=20, random_state=0
    )
    for cov_type in ["spherical", "diag", "tied", "full"]
}

n_estimators = len(estimators)
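The four covariance types store their fitted parameters in differently shaped arrays. As an optional illustration (using clones so the estimators prepared above stay untouched), you can fit each one on the training data and inspect covariances_:

from sklearn.base import clone

# Expected shapes for k components and d features:
#   spherical -> (k,), diag -> (k, d), tied -> (d, d), full -> (k, d, d)
for cov_type, est in estimators.items():
    fitted = clone(est).fit(X_train)
    print(cov_type, fitted.covariances_.shape)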

Define a function to plot ellipses for GMMs

def make_ellipses(gmm, ax):
    """Draw one covariance ellipse per mixture component on the given axes."""
    for n, color in enumerate(colors):
        # Recover a covariance matrix for the plotted (first two) features
        # from the fitted parametrization.
        if gmm.covariance_type == "full":
            covariances = gmm.covariances_[n][:2, :2]
        elif gmm.covariance_type == "tied":
            covariances = gmm.covariances_[:2, :2]
        elif gmm.covariance_type == "diag":
            covariances = np.diag(gmm.covariances_[n][:2])
        elif gmm.covariance_type == "spherical":
            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
        # Eigen-decomposition gives the ellipse orientation and axis lengths.
        v, w = np.linalg.eigh(covariances)
        u = w[0] / np.linalg.norm(w[0])
        angle = np.arctan2(u[1], u[0])
        angle = 180 * angle / np.pi  # convert to degrees
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)  # full axis lengths
        ell = mpl.patches.Ellipse(
            gmm.means_[n, :2], v[0], v[1], angle=180 + angle, color=color
        )
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(0.5)
        ax.add_artist(ell)
        ax.set_aspect("equal", "datalim")
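The geometry in make_ellipses comes from the eigen-decomposition of the covariance block: the eigenvalues are the variances along the principal axes, and the factor 2 * sqrt(2) turns their square roots into full axis lengths (a contour at roughly 1.4 standard deviations). The small standalone check below, on a hand-made covariance matrix, is only an illustration of that calculation and is not part of the original example.

# A 2x2 covariance with correlated features.
cov = np.array([[2.0, 1.2],
                [1.2, 1.0]])

eigvals, eigvecs = np.linalg.eigh(cov)          # eigenvalues in ascending order
widths = 2.0 * np.sqrt(2.0) * np.sqrt(eigvals)  # full axis lengths of the ellipse
angle = np.degrees(np.arctan2(eigvecs[0, 1], eigvecs[0, 0]))  # same convention as make_ellipses

print(eigvals)  # variances along the principal axes (here 0.2 and 2.8)
print(widths)
print(angle)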

Plot GMMs for different covariance types

plt.figure(figsize=(3 * n_estimators // 2, 6))
plt.subplots_adjust(
    bottom=0.01, top=0.95, hspace=0.15, wspace=0.05, left=0.01, right=0.99
)

for index, (name, estimator) in enumerate(estimators.items()):
    # Since we have class labels for the training data, we can
    # initialize the GMM means in a supervised manner.
    estimator.means_init = np.array(
        [X_train[y_train == i].mean(axis=0) for i in range(n_classes)]
    )

    # Train the remaining parameters with the EM algorithm.
    estimator.fit(X_train)

    h = plt.subplot(2, n_estimators // 2, index + 1)
    make_ellipses(estimator, h)

    # Plot the data points as dots, colored by their true class.
    for n, color in enumerate(colors):
        data = iris.data[iris.target == n]
        plt.scatter(
            data[:, 0], data[:, 1], s=0.8, color=color, label=iris.target_names[n]
        )

    # Plot the test data with crosses.
    for n, color in enumerate(colors):
        data = X_test[y_test == n]
        plt.scatter(data[:, 0], data[:, 1], marker="x", color=color)

    y_train_pred = estimator.predict(X_train)
    train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
    plt.text(0.05, 0.9, "Train accuracy: %.1f" % train_accuracy, transform=h.transAxes)

    y_test_pred = estimator.predict(X_test)
    test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
    plt.text(0.05, 0.8, "Test accuracy: %.1f" % test_accuracy, transform=h.transAxes)

    plt.xticks(())
    plt.yticks(())
    plt.title(name)

plt.legend(scatterpoints=1, loc="lower right", prop=dict(size=12))
plt.show()
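Because the loop above fits the estimators in place, the estimators dictionary now holds the trained models, so the accuracies can also be printed as a small text table instead of being read off the plot. This is an optional addition to the original example and assumes the previous cell has already been run.

# Summarize train/test accuracy per covariance type.
print(f"{'covariance':<12}{'train acc (%)':>15}{'test acc (%)':>14}")
for name, estimator in estimators.items():
    train_acc = np.mean(estimator.predict(X_train) == y_train) * 100
    test_acc = np.mean(estimator.predict(X_test) == y_test) * 100
    print(f"{name:<12}{train_acc:>15.1f}{test_acc:>14.1f}")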

Summary

This tutorial demonstrated the use of different covariance types for Gaussian mixture models (GMMs) in Python. We used the Iris dataset as an example and compared GMMs with spherical, diagonal, full, and tied covariance matrices in increasing order of performance. We plotted predicted labels on both training and held-out test data and showed that full covariance is prone to overfitting on small datasets and does not generalize well to held-out test data.
