Manifold Learning auf handschriftliche Ziffern

Beginner

This tutorial is from open-source community. Access the source code

Einführung

In diesem Lab werden wir verschiedene Techniken der Mannigfaltigkeits-Embedding auf dem Digits-Datensatz untersuchen. Wir werden verschiedene Techniken verwenden, um den Digits-Datensatz zu embedden, die Projektion der ursprünglichen Daten auf jedes Embedding zu plotten und die Ergebnisse, die aus verschiedenen Embedding-Methoden erhalten werden, zu vergleichen.

Tipps für die VM

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu öffnen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.

Laden des Digits-Datensatzes

Wir werden den Digits-Datensatz laden und nur sechs der zehn verfügbaren Klassen verwenden. Wir werden auch die ersten hundert Ziffern aus diesem Datensatz plotten.

## Load digits dataset
from sklearn.datasets import load_digits

digits = load_digits(n_class=6)
X, y = digits.data, digits.target
n_samples, n_features = X.shape
n_neighbors = 30

## Plot first hundred digits
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=10, ncols=10, figsize=(6, 6))
for idx, ax in enumerate(axs.ravel()):
    ax.imshow(X[idx].reshape((8, 8)), cmap=plt.cm.binary)
    ax.axis("off")
_ = fig.suptitle("A selection from the 64-dimensional digits dataset", fontsize=16)

Plot-Embedding-Funktion

Wir werden eine Hilfsfunktion definieren, um das Embedding zu plotten. Die Funktion nimmt die Embedding-Daten und den Titel für das Plot als Eingabe. Die Funktion wird jede Ziffer auf dem Embedding plotten und eine Anmerkungskästchen für eine Gruppe von Ziffern anzeigen.

import numpy as np
from matplotlib import offsetbox
from sklearn.preprocessing import MinMaxScaler

def plot_embedding(X, title):
    _, ax = plt.subplots()
    X = MinMaxScaler().fit_transform(X)

    for digit in digits.target_names:
        ax.scatter(
            *X[y == digit].T,
            marker=f"${digit}$",
            s=60,
            color=plt.cm.Dark2(digit),
            alpha=0.425,
            zorder=2,
        )
    shown_images = np.array([[1.0, 1.0]])  ## just something big
    for i in range(X.shape[0]):
        ## plot every digit on the embedding
        ## show an annotation box for a group of digits
        dist = np.sum((X[i] - shown_images) ** 2, 1)
        if np.min(dist) < 4e-3:
            ## don't show points that are too close
            continue
        shown_images = np.concatenate([shown_images, [X[i]]], axis=0)
        imagebox = offsetbox.AnnotationBbox(
            offsetbox.OffsetImage(digits.images[i], cmap=plt.cm.gray_r), X[i]
        )
        imagebox.set(zorder=1)
        ax.add_artist(imagebox)

    ax.set_title(title)
    ax.axis("off")

Vergleiche Embedding-Techniken

Wir werden verschiedene Embedding-Techniken mit unterschiedlichen Methoden vergleichen. Wir werden die projizierten Daten sowie die Rechenzeit speichern, die für jede Projection benötigt wird.

from sklearn.decomposition import TruncatedSVD
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomTreesEmbedding
from sklearn.manifold import (
    Isomap,
    LocallyLinearEmbedding,
    MDS,
    SpectralEmbedding,
    TSNE,
)
from sklearn.neighbors import NeighborhoodComponentsAnalysis
from sklearn.pipeline import make_pipeline
from sklearn.random_projection import SparseRandomProjection

embeddings = {
    "Random projection embedding": SparseRandomProjection(
        n_components=2, random_state=42
    ),
    "Truncated SVD embedding": TruncatedSVD(n_components=2),
    "Linear Discriminant Analysis embedding": LinearDiscriminantAnalysis(
        n_components=2
    ),
    "Isomap embedding": Isomap(n_neighbors=n_neighbors, n_components=2),
    "Standard LLE embedding": LocallyLinearEmbedding(
        n_neighbors=n_neighbors, n_components=2, method="standard"
    ),
    "Modified LLE embedding": LocallyLinearEmbedding(
        n_neighbors=n_neighbors, n_components=2, method="modified"
    ),
    "Hessian LLE embedding": LocallyLinearEmbedding(
        n_neighbors=n_neighbors, n_components=2, method="hessian"
    ),
    "LTSA LLE embedding": LocallyLinearEmbedding(
        n_neighbors=n_neighbors, n_components=2, method="ltsa"
    ),
    "MDS embedding": MDS(
        n_components=2, n_init=1, max_iter=120, n_jobs=2, normalized_stress="auto"
    ),
    "Random Trees embedding": make_pipeline(
        RandomTreesEmbedding(n_estimators=200, max_depth=5, random_state=0),
        TruncatedSVD(n_components=2),
    ),
    "Spectral embedding": SpectralEmbedding(
        n_components=2, random_state=0, eigen_solver="arpack"
    ),
    "t-SNE embeedding": TSNE(
        n_components=2,
        n_iter=500,
        n_iter_without_progress=150,
        n_jobs=2,
        random_state=0
    ),
    "NCA embedding": NeighborhoodComponentsAnalysis(
        n_components=2, init="pca", random_state=0
    ),
}

projections, timing = {}, {}
for name, transformer in embeddings.items():
    if name.startswith("Linear Discriminant Analysis"):
        data = X.copy()
        data.flat[:: X.shape[1] + 1] += 0.01  ## Make X invertible
    else:
        data = X

    print(f"Computing {name}...")
    start_time = time()
    projections[name] = transformer.fit_transform(data, y)
    timing[name] = time() - start_time

Ergebnisse plotten

Wir werden die resultierende Projection, die von jeder Methode gegeben wird, plotten.

for name in timing:
    title = f"{name} (time {timing[name]:.3f}s)"
    plot_embedding(projections[name], title)

plt.show()

Zusammenfassung

In diesem Lab haben wir verschiedene Manifold-Embedding-Techniken auf dem Digits-Datensatz untersucht. Wir haben verschiedene Techniken verwendet, um den Digits-Datensatz zu embedden, den Projektion der ursprünglichen Daten auf jedes Embedding geplottet und die Ergebnisse verglichen, die aus verschiedenen Embedding-Methoden erhalten wurden. Die Ergebnisse geben Einblicke in die Effektivität jeder Embedding-Methode bei der Gruppierung ähnlicher Ziffern zusammen im Embedding-Raum.