Einführung
In diesem Lab werden Sie durch den Prozess geführt, die Mean-Shift-Clustering-Algorithmus mit der Scikit-learn-Bibliothek in Python zu implementieren. Sie werden lernen, wie man Stichprobendaten generiert, die Clustering mit MeanShift berechnet und die Ergebnisse darstellt.
Tipps für die VM
Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.
Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.
Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.
Importieren der erforderlichen Bibliotheken
Zunächst müssen wir die erforderlichen Bibliotheken importieren: numpy, sklearn.cluster, sklearn.datasets und matplotlib.pyplot.
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
Generieren von Stichprobendaten
Als nächstes werden wir Stichprobendaten mit der Funktion make_blobs aus dem Modul sklearn.datasets generieren. Wir werden einen Datensatz mit 10.000 Stichproben und drei Clustern mit Mittelpunkten bei [[1, 1], [-1, -1], [1, -1]] und einer Standardabweichung von 0,6 erstellen.
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)
Berechnen der Clustering mit MeanShift
Jetzt werden wir die Clustering mit dem Mean-Shift-Algorithmus mit der Klasse MeanShift aus dem Modul sklearn.cluster berechnen. Wir werden die Funktion estimate_bandwidth verwenden, um den Bandbreitenparameter automatisch zu erkennen, der die Größe der Einflussregion für jeden Punkt darstellt.
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_
Ergebnisse plotten
Schließlich werden wir die Ergebnisse mit der Bibliothek matplotlib.pyplot plotten. Wir werden für jedes Cluster unterschiedliche Farben und Marker verwenden und auch die Clusterzentren plotten.
plt.figure(1)
plt.clf()
colors = ["#dede00", "#377eb8", "#f781bf"]
markers = ["x", "o", "^"]
for k, col in zip(range(n_clusters_), colors):
my_members = labels == k
cluster_center = cluster_centers[k]
plt.plot(X[my_members, 0], X[my_members, 1], markers[k], color=col)
plt.plot(
cluster_center[0],
cluster_center[1],
markers[k],
markerfacecolor=col,
markeredgecolor="k",
markersize=14,
)
plt.title("Estimated number of clusters: %d" % n_clusters_)
plt.show()
Zusammenfassung
In diesem Lab haben wir gelernt, wie man den Mean-Shift-Clustering-Algorithmus mit der Scikit-learn-Bibliothek in Python implementiert. Wir haben Stichprobendaten generiert, die Clustering mit MeanShift berechnet und die Ergebnisse geplottet. Am Ende dieses Labs sollten Sie ein besseres Verständnis dafür haben, wie man den Mean-Shift-Algorithmus verwendet, um Daten zu clustern.