XOR データセットにおけるガウス過程分類

Beginner

This tutorial is from open-source community. Access the source code

はじめに

この実験では、scikit-learn を使って XOR データセットに対するガウス過程分類 (GPC) の例を見ていきます。定常的で等方的なカーネル (RBF) と非定常的なカーネル (DotProduct) を使って得られる結果を比較します。

VM のヒント

VM の起動が完了したら、左上隅をクリックして ノートブック タブに切り替え、Jupyter Notebook を使って練習しましょう。

Jupyter Notebook の読み込みには数秒かかる場合があります。Jupyter Notebook の制限により、操作の検証は自動化できません。

学習中に問題がある場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

ライブラリのインポート

このステップでは、この実験に必要なライブラリをインポートします。

import numpy as np
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF, DotProduct

XOR データセットの作成

このステップでは、numpy を使って XOR データセットを作成します。入力特徴に基づいてラベルを作成するために logical_xor 関数を使います。

xx, yy = np.meshgrid(np.linspace(-3, 3, 50), np.linspace(-3, 3, 50))
rng = np.random.RandomState(0)
X = rng.randn(200, 2)
Y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0)

モデルのフィッティング

このステップでは、ガウス過程分類器をデータセットにフィッティングさせます。比較のために 2 つの異なるカーネル - RBF と DotProduct を使います。

plt.figure(figsize=(10, 5))
kernels = [1.0 * RBF(length_scale=1.15), 1.0 * DotProduct(sigma_0=1.0) ** 2]
for i, kernel in enumerate(kernels):
    clf = GaussianProcessClassifier(kernel=kernel, warm_start=True).fit(X, Y)

    ## plot the decision function for each datapoint on the grid
    Z = clf.predict_proba(np.vstack((xx.ravel(), yy.ravel())).T)[:, 1]
    Z = Z.reshape(xx.shape)

    plt.subplot(1, 2, i + 1)
    image = plt.imshow(
        Z,
        interpolation="nearest",
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        aspect="auto",
        origin="lower",
        cmap=plt.cm.PuOr_r,
    )
    contours = plt.contour(xx, yy, Z, levels=[0.5], linewidths=2, colors=["k"])
    plt.scatter(X[:, 0], X[:, 1], s=30, c=Y, cmap=plt.cm.Paired, edgecolors=(0, 0, 0))
    plt.xticks(())
    plt.yticks(())
    plt.axis([-3, 3, -3, 3])
    plt.colorbar(image)
    plt.title(
        "%s\n Log-Marginal-Likelihood:%.3f"
        % (clf.kernel_, clf.log_marginal_likelihood(clf.kernel_.theta)),
        fontsize=12,
    )

plt.tight_layout()
plt.show()

結果の可視化

このステップでは、モデルをフィッティングさせて得られた結果を可視化します。グリッド上の各データポイントの決定関数をプロットし、入力特徴に対する散布図を描きます。

まとめ

この実験では、scikit - learn を使って XOR データセットに対するガウス過程分類 (GPC) の例を見てきました。定常的で等方性のあるカーネル (RBF) と非定常的なカーネル (DotProduct) を使って得られた結果を比較しました。