Iris データセットにおけるガウス過程分類

Beginner

This tutorial is from open-source community. Access the source code

はじめに

この実験では、Iris データセットでガウス過程分類 (GPC) をどのように使用するかを検討します。Iris データセットは、3 種類の異なる Iris 花の花弁と花萼の長さと幅に関する情報が含まれる有名なデータセットです。我々は、分類タスクの確率的アプローチである GPC を実装するために scikit-learn を使用します。

VM のヒント

VM の起動が完了した後、左上隅をクリックしてノートブックタブに切り替え、Jupyter Notebook を使って練習しましょう。

時々、Jupyter Notebook が読み込み終了するまで数秒待つ必要がある場合があります。Jupyter Notebook の制限により、操作の検証は自動化できません。

学習中に問題に直面した場合は、Labby にお尋ねください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

必要なライブラリとデータセットのインポート

まず、必要なライブラリをインポートし、scikit-learn から Iris データセットを読み込みます。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

iris = datasets.load_iris()
X = iris.data[:, :2]  ## 最初の 2 つの特徴のみを取ります。
y = np.array(iris.target, dtype=int)

カーネル関数の定義

次に、カーネル関数を定義します。この例では、Radial Basis Function (RBF) カーネルを使用します。RBF カーネルの 2 つのバージョンを定義します。等方性バージョンと異方性バージョンです。

kernel = 1.0 * RBF([1.0])
gpc_rbf_isotropic = GaussianProcessClassifier(kernel=kernel).fit(X, y)

kernel = 1.0 * RBF([1.0, 1.0])
gpc_rbf_anisotropic = GaussianProcessClassifier(kernel=kernel).fit(X, y)

プロット用のメッシュ作成

次に、プロット用のメッシュを作成します。このメッシュは、メッシュ上の各点に対する予測確率をプロットするために使用されます。また、メッシュの刻み幅も定義します。

h = 0.02  ## メッシュの刻み幅

x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

予測確率のプロット

次に、メッシュ上の各点の予測確率をプロットします。2 つのサブプロットを作成します。1 つは等方性 RBF カーネル用で、もう 1 つは異方性 RBF カーネル用です。メッシュ上の各点の予測確率を取得するために predict_proba メソッドを使用します。その後、予測確率をメッシュ上のカラープロットとしてプロットします。また、各種の Iris 花の学習ポイントもプロットします。

titles = ["Isotropic RBF", "Anisotropic RBF"]
plt.figure(figsize=(10, 5))
for i, clf in enumerate((gpc_rbf_isotropic, gpc_rbf_anisotropic)):
    ## 予測確率をプロットします。そのために、メッシュ [x_min, m_max]x[y_min, y_max] の各点に色を割り当てます。
    plt.subplot(1, 2, i + 1)

    Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])

    ## 結果をカラープロットに入れます
    Z = Z.reshape((xx.shape[0], xx.shape[1], 3))
    plt.imshow(Z, extent=(x_min, x_max, y_min, y_max), origin="lower")

    ## 学習ポイントもプロットします
    plt.scatter(X[:, 0], X[:, 1], c=np.array(["r", "g", "b"])[y], edgecolors=(0, 0, 0))
    plt.xlabel("Sepal length")
    plt.ylabel("Sepal width")
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xticks(())
    plt.yticks(())
    plt.title(
        "%s, LML: %.3f" % (titles[i], clf.log_marginal_likelihood(clf.kernel_.theta))
    )

plt.tight_layout()
plt.show()

まとめ

この実験では、scikit-learn を使って Iris データセットでガウス過程分類 (GPC) をどのように使用するかを検討しました。Radial Basis Function (RBF) カーネルの 2 つのバージョン、すなわち等方性バージョンと異方性バージョンを定義しました。その後、メッシュ上の各点の予測確率をプロットするためのメッシュを作成し、予測確率をメッシュ上のカラープロットとしてプロットしました。最後に、各種の Iris 花の学習ポイントをプロットしました。