使用高斯混合模型进行密度估计

Machine LearningMachine LearningBeginner
立即练习

This tutorial is from open-source community. Access the source code

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

在本实验中,我们将使用 scikit-learn 库生成一个高斯混合数据集。然后,我们将对该数据集拟合一个高斯混合模型(Gaussian Mixture Model,GMM),并绘制高斯混合的密度估计图。GMM 可用于对数据集的概率分布进行建模和估计。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills ml/sklearn -.-> lab-49136{{"使用高斯混合模型进行密度估计"}} end

导入库

我们将首先导入必要的库:用于数值计算的 NumPy 和用于可视化的 Matplotlib。我们还将从 scikit-learn 库中导入 GaussianMixture 类,以便将 GMM 拟合到我们的数据集。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import mixture

生成数据

接下来,我们将生成一个包含两个组件的高斯混合数据集。我们将创建一个以 (20, 20) 为中心的平移高斯数据集和一个零中心拉伸高斯数据集。然后,我们将把这两个数据集连接成最终的训练集。

n_samples = 300

## 生成随机样本,两个组件
np.random.seed(0)

## 生成以 (20, 20) 为中心的球形数据
shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])

## 生成零中心拉伸高斯数据
C = np.array([[0.0, -0.7], [3.5, 0.7]])
stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)

## 将两个数据集连接成最终的训练集
X_train = np.vstack([shifted_gaussian, stretched_gaussian])

拟合高斯混合模型

现在,我们将使用 scikit-learn 中的 GaussianMixture 类对数据集拟合一个 GMM。我们将把组件数量设置为 2,协方差类型设置为“full”。

## 拟合一个具有两个组件的高斯混合模型
clf = mixture.GaussianMixture(n_components=2, covariance_type="full")
clf.fit(X_train)

绘制密度估计图

现在我们将绘制高斯混合的密度估计图。我们将在数据集的范围内创建一个点的网格,并计算 GMM 对每个点预测的负对数似然。然后,我们将把预测分数显示为等高线图,并将训练数据绘制成散点图。

## 将模型预测的分数显示为等高线图
x = np.linspace(-20.0, 30.0)
y = np.linspace(-20.0, 40.0)
X, Y = np.meshgrid(x, y)
XX = np.array([X.ravel(), Y.ravel()]).T
Z = -clf.score_samples(XX)
Z = Z.reshape(X.shape)

CS = plt.contour(
    X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), levels=np.logspace(0, 3, 10)
)
CB = plt.colorbar(CS, shrink=0.8, extend="both")
plt.scatter(X_train[:, 0], X_train[:, 1], 0.8)

plt.title("Density Estimation with Gaussian Mixture Models")
plt.axis("tight")
plt.show()

总结

在这个实验中,我们学习了如何使用 scikit-learn 生成高斯混合数据集,并将高斯混合模型(GMM)拟合到该数据集上。我们还使用等高线图绘制了高斯混合的密度估计。