ガウス混合モデルによる密度推定

Beginner

This tutorial is from open-source community. Access the source code

はじめに

この実験では、scikit-learn ライブラリを使ってガウス混合データセットを生成します。その後、ガウス混合モデル(GMM)をデータセットに適合させ、ガウス混合の密度推定をプロットします。GMM は、データセットの確率分布をモデル化し推定するために使用できます。

VM のヒント

VM の起動が完了したら、左上隅をクリックしてノートブックタブに切り替え、Jupyter Notebook を使って練習しましょう。

時々、Jupyter Notebook が読み込み終了するまで数秒待つ必要がある場合があります。Jupyter Notebook の制限により、操作の検証を自動化することはできません。

学習中に問題に遭遇した場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

ライブラリのインポート

必要なライブラリをインポートして始めましょう。数値計算には NumPy を、可視化には Matplotlib を使用します。また、scikit-learn ライブラリから GaussianMixture クラスをインポートして、GMM をデータセットに適合させます。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import mixture

データの生成

次に、2 つのコンポーネントを持つガウス混合データセットを生成します。(20, 20) を中心としたシフトされたガウスデータセットと、ゼロ中心の伸ばされたガウスデータセットを作成します。その後、2 つのデータセットを連結して最終的な学習セットにします。

n_samples = 300

## 乱数のシードを設定
np.random.seed(0)

## (20, 20) を中心とした球状のデータを生成
shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])

## ゼロ中心の伸ばされたガウスデータを生成
C = np.array([[0.0, -0.7], [3.5, 0.7]])
stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)

## 2 つのデータセットを連結して最終的な学習セットにする
X_train = np.vstack([shifted_gaussian, stretched_gaussian])

ガウス混合モデルの適合

ここでは、scikit-learn の GaussianMixture クラスを使って GMM をデータセットに適合させます。コンポーネント数を 2 に設定し、共分散のタイプを「完全」に設定します。

## 2 つのコンポーネントを持つガウス混合モデルを適合させる
clf = mixture.GaussianMixture(n_components=2, covariance_type="full")
clf.fit(X_train)

密度推定のプロット

次に、ガウス混合の密度推定をプロットします。データセットの範囲にわたって点のメッシュグリッドを作成し、各点に対して GMM によって予測される負の対数尤度を計算します。その後、予測されたスコアを等高線プロットとして表示し、学習データを散布図で表示します。

## モデルによって予測されたスコアを等高線プロットとして表示する
x = np.linspace(-20.0, 30.0)
y = np.linspace(-20.0, 40.0)
X, Y = np.meshgrid(x, y)
XX = np.array([X.ravel(), Y.ravel()]).T
Z = -clf.score_samples(XX)
Z = Z.reshape(X.shape)

CS = plt.contour(
    X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), levels=np.logspace(0, 3, 10)
)
CB = plt.colorbar(CS, shrink=0.8, extend="both")
plt.scatter(X_train[:, 0], X_train[:, 1], 0.8)

plt.title("Density Estimation with Gaussian Mixture Models")
plt.axis("tight")
plt.show()

まとめ

この実験では、scikit-learn を使ってガウス混合データセットを生成し、そのデータセットに GMM を適合させる方法を学びました。また、等高線プロットを使ってガウス混合の密度推定をプロットしました。