最寄り重心分類

Beginner

This tutorial is from open-source community. Access the source code

はじめに

この実験では、Scikit-learn を使った最寄りの重心分類の実装方法を学びます。最寄りの重心分類は、各クラスの重心を計算し、その重心に最も近い新しいデータポイントに基づいて分類する、シンプルな分類手法です。

VM のヒント

VM の起動が完了したら、左上隅をクリックしてノートブックタブに切り替え、Jupyter Notebook を使って練習しましょう。

時々、Jupyter Notebook が読み込み完了するまで数秒待つ必要があります。Jupyter Notebook の制限により、操作の検証は自動化できません。

学習中に問題があった場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

必要なライブラリをインポートする

まず、必要なライブラリをインポートする必要があります。これには、Numpy、Matplotlib、Scikit-learn のデータセット、NearestCentroid、および DecisionBoundaryDisplay が含まれます。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import datasets
from sklearn.neighbors import NearestCentroid
from sklearn.inspection import DecisionBoundaryDisplay

データを読み込む

次に、Scikit-learn から iris データセットを読み込み、可視化の目的で最初の 2 つの特徴のみを選択します。

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

カラーマップを作成する

Matplotlib の ListedColormap 関数を使って、可視化の目的で 2 つのカラーマップを作成します。

cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

分類器を作成して適合させる

収縮値 0.2 の最寄りの重心分類器のインスタンスを作成し、データに適合させます。

clf = NearestCentroid(shrink_threshold=0.2)
clf.fit(X, y)

予測と精度の測定

入力データのクラスラベルを予測し、分類器の精度を測定します。

y_pred = clf.predict(X)
print("Accuracy: ", np.mean(y == y_pred))

決定境界を可視化する

Scikit-learn の DecisionBoundaryDisplay 関数を使って、分類器の決定境界を可視化します。

_, ax = plt.subplots()
DecisionBoundaryDisplay.from_estimator(
    clf, X, cmap=cmap_light, ax=ax, response_method="predict"
)

データポイントをプロットする

Matplotlib の scatter 関数を使って、入力データポイントをプロットします。

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20)

タイトルと軸ラベルを追加する

Matplotlib の title、xlabel、ylabel 関数を使って、プロットにタイトルと軸ラベルを追加します。

plt.title("Nearest Centroid Classification")
plt.xlabel("Sepal length")
plt.ylabel("Sepal width")

プロットを表示する

Matplotlib の show 関数を使って、プロットを表示します。

plt.show()

まとめ

この実験では、Scikit-learn を使って最寄り重心分類を実装する方法を学びました。アイリスデータセットを読み込み、分類器を作成し、クラスラベルを予測し、精度を測定し、決定境界と入力データポイントを可視化しました。