Mean-Shift クラスタリングアルゴリズム

Machine LearningMachine LearningBeginner
オンラインで実践に進む

This tutorial is from open-source community. Access the source code

💡 このチュートリアルは英語版からAIによって翻訳されています。原文を確認するには、 ここをクリックしてください

はじめに

この実験では、Python の Scikit-learn ライブラリを使用して、Mean-Shift クラスタリングアルゴリズムを実装するプロセスを案内します。サンプルデータを生成し、MeanShift を使ってクラスタリングを計算し、結果をプロットする方法を学びます。

VM のヒント

VM の起動が完了したら、左上隅をクリックして ノートブック タブに切り替え、Jupyter Notebook を使って練習しましょう。

時々、Jupyter Notebook が読み込み完了するまで数秒待つ必要がある場合があります。Jupyter Notebook の制限により、操作の検証を自動化することはできません。

学習中に問題に遭遇した場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

必要なライブラリをインポートする

まず、必要なライブラリをインポートする必要があります。numpysklearn.clustersklearn.datasets、およびmatplotlib.pyplotです。

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

サンプルデータを生成する

次に、sklearn.datasets モジュールの make_blobs 関数を使ってサンプルデータを生成します。中心が [[1, 1], [-1, -1], [1, -1]] で標準偏差が 0.6 の 3 つのクラスタを持つ 10,000 個のサンプルからなるデータセットを作成します。

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

MeanShift を使ってクラスタリングを計算する

次に、sklearn.cluster モジュールの MeanShift クラスを使って Mean-Shift アルゴリズムを使ってクラスタリングを計算します。各点の影響領域のサイズを表すバンド幅パラメータを自動的に検出するために、estimate_bandwidth 関数を使用します。

bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

結果をプロットする

最後に、matplotlib.pyplot ライブラリを使って結果をプロットします。各クラスタに対して異なる色とマーカーを使用し、クラスタ中心もプロットします。

plt.figure(1)
plt.clf()

colors = ["#dede00", "#377eb8", "#f781bf"]
markers = ["x", "o", "^"]

for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], markers[k], color=col)
    plt.plot(
        cluster_center[0],
        cluster_center[1],
        markers[k],
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=14,
    )
plt.title("Estimated number of clusters: %d" % n_clusters_)
plt.show()

まとめ

この実験では、Python の Scikit-learn ライブラリを使って Mean-Shift クラスタリングアルゴリズムを実装する方法を学びました。サンプルデータを生成し、MeanShift を使ってクラスタリングを計算し、結果をプロットしました。この実験が終わるまでに、データをクラスタリングするための Mean-Shift アルゴリズムの使い方をより深く理解しているはずです。