Agglomeratives Clustering auf dem Digits-Datensatz

Machine LearningMachine LearningBeginner
Jetzt üben

This tutorial is from open-source community. Access the source code

💡 Dieser Artikel wurde von AI-Assistenten übersetzt. Um die englische Version anzuzeigen, können Sie hier klicken

Einführung

In diesem Lab werden verschiedene Verknüpfungsmöglichkeiten für die Agglomerative Clustering auf einer 2D-Embedding des Digits-Datensatzes veranschaulicht. Ziel dieses Labs ist es, zu zeigen, wie sich verschiedene Verknüpfungsstrategien verhalten, und nicht, gute Cluster für die Ziffern zu finden. Deshalb funktioniert das Beispiel auf einer 2D-Embedding.

VM-Tipps

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL sklearn(("Sklearn")) -.-> sklearn/CoreModelsandAlgorithmsGroup(["Core Models and Algorithms"]) ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) sklearn/CoreModelsandAlgorithmsGroup -.-> sklearn/cluster("Clustering") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills sklearn/cluster -.-> lab-49111{{"Agglomeratives Clustering auf dem Digits-Datensatz"}} ml/sklearn -.-> lab-49111{{"Agglomeratives Clustering auf dem Digits-Datensatz"}} end

Bibliotheken importieren

Wir beginnen, indem wir die erforderlichen Bibliotheken für dieses Lab importieren. Wir werden numpy, matplotlib, manifold und datasets aus scikit-learn verwenden, um Agglomerative Clustering durchzuführen.

import numpy as np
from matplotlib import pyplot as plt
from time import time
from sklearn import manifold, datasets
from sklearn.cluster import AgglomerativeClustering

Datensatz laden und vorbereiten

Wir laden den Digits-Datensatz und bereiten ihn für das Clustering vor, indem wir die Daten und die Ziel-Labels extrahieren. Wir setzen auch den Zufallszahlengenerator auf Null, um die Reproduzierbarkeit zu gewährleisten.

digits = datasets.load_digits()
X, y = digits.data, digits.target
n_samples, n_features = X.shape
np.random.seed(0)

Datensatz visualisieren

Wir visualisieren den Datensatz, indem wir eine 2D-Embedding des Digits-Datensatzes mit manifold.SpectralEmbedding() berechnen und den Scatterplot mit unterschiedlichen Markern für jede Ziffer zeichnen.

def plot_dataset(X_red):
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)

    plt.figure(figsize=(6, 4))
    for digit in digits.target_names:
        plt.scatter(
            *X_red[y == digit].T,
            marker=f"${digit}$",
            s=50,
            alpha=0.5,
        )

    plt.xticks([])
    plt.yticks([])
    plt.title('Digits Dataset Scatter Plot', size=17)
    plt.axis("off")
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

print("Computing embedding")
X_red = manifold.SpectralEmbedding(n_components=2).fit_transform(X)
print("Done.")
plot_dataset(X_red)

Agglomeratives Clustering mit verschiedenen Verknüpfungsstrategien

Wir führen Agglomeratives Clustering mit verschiedenen Verknüpfungsstrategien durch: ward, average, complete und single. Wir setzen die Anzahl der Cluster auf 10 für alle Strategien. Anschließend zeichnen wir die Clustering-Ergebnisse mit unterschiedlichen Farben für jede Ziffer.

def plot_clustering(X_red, labels, title=None):
    x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
    X_red = (X_red - x_min) / (x_max - x_min)

    plt.figure(figsize=(6, 4))
    for digit in digits.target_names:
        plt.scatter(
            *X_red[y == digit].T,
            marker=f"${digit}$",
            s=50,
            c=plt.cm.nipy_spectral(labels[y == digit] / 10),
            alpha=0.5,
        )

    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title, size=17)
    plt.axis("off")
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

for linkage in ("ward", "average", "complete", "single"):
    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
    t0 = time()
    clustering.fit(X_red)
    print("%s :\t%.2fs" % (linkage, time() - t0))

    plot_clustering(X_red, clustering.labels_, "%s linkage" % linkage)

plt.show()

Zusammenfassung

In diesem Lab haben wir gelernt, wie man Agglomeratives Clustering auf dem Digits-Datensatz mit verschiedenen Verknüpfungsstrategien durchführt. Wir haben auch den Datensatz und die Clustering-Ergebnisse für jede Strategie visualisiert. Die Ergebnisse zeigen, dass verschiedene Verknüpfungsstrategien unterschiedliche Clustering-Ergebnisse erzeugen, und wir sollten die Strategie wählen, die am besten unseren Anforderungen entspricht.