Kreuzvalidierungstechniken mit Scikit-Learn

Machine LearningMachine LearningBeginner
Jetzt üben

This tutorial is from open-source community. Access the source code

💡 Dieser Artikel wurde von AI-Assistenten übersetzt. Um die englische Version anzuzeigen, können Sie hier klicken

Einführung

Die Kreuzvalidierung ist eine Technik zur Evaluierung von Machine-Learning-Modellen, bei der mehrere Modelle auf verschiedenen Teilmengen der verfügbaren Daten trainiert und auf der komplementären Teilmenge ausgewertet werden. Dies hilft, Overfitting zu vermeiden und gewährleistet, dass das Modell verallgemeinerbar ist. In diesem Tutorial werden wir scikit-learn verwenden, um verschiedene Kreuzvalidierungstechniken und ihr Verhalten zu untersuchen.

Tipps für die VM

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.

Visualisiere die Daten

In diesem Schritt werden wir die Daten visualisieren, mit denen wir arbeiten werden. Die Daten bestehen aus 100 zufällig generierten Eingabedatenpunkten, 3 Klassen, die ungleichmäßig über die Datenpunkte verteilt sind, und 10 "Gruppen", die gleichmäßig über die Datenpunkte verteilt sind.

## Generate the class/group data
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generate uneven groups
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")

Visualisiere die Daten

In diesem Schritt visualisieren wir die Daten, mit denen wir arbeiten werden. Die Daten bestehen aus 100 zufällig generierten Eingabedatenpunkten, 3 Klassen, die ungleichmäßig über die Datenpunkte verteilt sind, und 10 "Gruppen", die gleichmäßig über die Datenpunkte verteilt sind.

## Generiere die Klassendaten/Gruppendaten
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0,1, 0,3, 0,6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generiere ungleiche Gruppen
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## Visualisiere die Datensatzgruppen
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0,5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3,5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0,5, 3,5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")

Visualisiere die Indizes der Kreuzvalidierung

In diesem Schritt definieren wir eine Funktion, um das Verhalten jedes Kreuzvalidierungsobjekts zu visualisieren. Wir werden 4 Aufteilungen der Daten vornehmen. Bei jeder Aufteilung visualisieren wir die Indizes, die für den Trainingssatz (blau) und den Testsatz (rot) ausgewählt wurden.

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    ## Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Visualisiere die Indizes der Kreuzvalidierung

In diesem Schritt definieren wir eine Funktion, um das Verhalten jedes Kreuzvalidierungsobjekts zu visualisieren. Wir werden 4 Aufteilungen der Daten vornehmen. Bei jeder Aufteilung visualisieren wir die Indizes, die für den Trainingssatz (blau) und den Testsatz (rot) ausgewählt wurden.

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Erstelle ein Beispielplot für die Indizes eines Kreuzvalidierungsobjekts."""

    ## Generiere die Visualisierungen für die Trainings- und Testdaten für jede Kreuzvalidierungsaufteilung
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Fülle die Indizes mit den Trainings- und Testgruppen
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualisiere die Ergebnisse
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plotte die Datenklassen und -gruppen am Ende
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatierung
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Vergleiche Kreuzvalidierungsobjekte

In diesem Schritt werden wir das Kreuzvalidierungsverhalten für verschiedene scikit-learn-Kreuzvalidierungsobjekte vergleichen. Wir werden durch mehrere übliche Kreuzvalidierungsobjekte iterieren und das Verhalten jedes visualisieren. Beachten Sie, wie einige die Gruppierungs-/Klasseninformation verwenden, während andere dies nicht tun.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Testing set", "Training set"],
        loc=(1.02, 0.8),
    )
    ## Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Vergleiche Kreuzvalidierungsobjekte

In diesem Schritt werden wir das Kreuzvalidierungsverhalten für verschiedene scikit-learn-Kreuzvalidierungsobjekte vergleichen. Wir werden durch mehrere übliche Kreuzvalidierungsobjekte iterieren und das Verhalten jedes visualisieren. Beachten Sie, wie einige die Gruppierungs-/Klasseninformation verwenden, während andere dies nicht tun.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Testing set", "Training set"],
        loc=(1.02, 0.8),
    )
    ## Mach die Legende passen
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Zusammenfassung

In diesem Tutorial haben wir scikit-learn verwendet, um verschiedene Kreuzvalidierungstechniken und ihr Verhalten zu untersuchen. Wir haben die Daten visualisiert, eine Funktion definiert, um die Indizes der Kreuzvalidierung zu visualisieren, und das Kreuzvalidierungsverhalten für verschiedene scikit-learn-Kreuzvalidierungsobjekte verglichen.