Mean-Shift-Clustering-Algorithmus

Machine LearningMachine LearningBeginner
Jetzt üben

This tutorial is from open-source community. Access the source code

💡 Dieser Artikel wurde von AI-Assistenten übersetzt. Um die englische Version anzuzeigen, können Sie hier klicken

Einführung

In diesem Lab werden Sie durch den Prozess geführt, die Mean-Shift-Clustering-Algorithmus mit der Scikit-learn-Bibliothek in Python zu implementieren. Sie werden lernen, wie man Stichprobendaten generiert, die Clustering mit MeanShift berechnet und die Ergebnisse darstellt.

Tipps für die VM

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.

Importieren der erforderlichen Bibliotheken

Zunächst müssen wir die erforderlichen Bibliotheken importieren: numpy, sklearn.cluster, sklearn.datasets und matplotlib.pyplot.

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

Generieren von Stichprobendaten

Als nächstes werden wir Stichprobendaten mit der Funktion make_blobs aus dem Modul sklearn.datasets generieren. Wir werden einen Datensatz mit 10.000 Stichproben und drei Clustern mit Mittelpunkten bei [[1, 1], [-1, -1], [1, -1]] und einer Standardabweichung von 0,6 erstellen.

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

Berechnen der Clustering mit MeanShift

Jetzt werden wir die Clustering mit dem Mean-Shift-Algorithmus mit der Klasse MeanShift aus dem Modul sklearn.cluster berechnen. Wir werden die Funktion estimate_bandwidth verwenden, um den Bandbreitenparameter automatisch zu erkennen, der die Größe der Einflussregion für jeden Punkt darstellt.

bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_

Ergebnisse plotten

Schließlich werden wir die Ergebnisse mit der Bibliothek matplotlib.pyplot plotten. Wir werden für jedes Cluster unterschiedliche Farben und Marker verwenden und auch die Clusterzentren plotten.

plt.figure(1)
plt.clf()

colors = ["#dede00", "#377eb8", "#f781bf"]
markers = ["x", "o", "^"]

for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], markers[k], color=col)
    plt.plot(
        cluster_center[0],
        cluster_center[1],
        markers[k],
        markerfacecolor=col,
        markeredgecolor="k",
        markersize=14,
    )
plt.title("Estimated number of clusters: %d" % n_clusters_)
plt.show()

Zusammenfassung

In diesem Lab haben wir gelernt, wie man den Mean-Shift-Clustering-Algorithmus mit der Scikit-learn-Bibliothek in Python implementiert. Wir haben Stichprobendaten generiert, die Clustering mit MeanShift berechnet und die Ergebnisse geplottet. Am Ende dieses Labs sollten Sie ein besseres Verständnis dafür haben, wie man den Mean-Shift-Algorithmus verwendet, um Daten zu clustern.