Kreuzvalidierung auf dem Digits-Datensatz

Beginner

This tutorial is from open-source community. Access the source code

Einführung

In diesem Lab wird die Kreuzvalidierung mit einer Support Vector Machine (SVM) auf dem Digits-Datensatz durchgeführt. Dies ist ein Klassifizierungsproblem, bei dem die Aufgabe besteht, Ziffern aus Bildern von handschriftlichen Ziffern zu identifizieren.

Tipps für die VM

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.

Lade den Datensatz

Zunächst müssen wir den Digits-Datensatz aus scikit-learn laden und ihn in Features und Labels unterteilen.

import numpy as np
from sklearn import datasets

X, y = datasets.load_digits(return_X_y=True)

Erstelle ein Support Vector Machine (SVM)-Modell

Als nächstes erstellen wir ein SVM-Modell mit einem linearen Kernel.

from sklearn import svm

svc = svm.SVC(kernel="linear")

Definiere die Hyperparameterwerte zum Testen

Wir werden verschiedene Werte des Regularisierungsparameters C testen, der das Kompromissverhältnis zwischen der Maximalisierung der Margin und der Minimierung der Klassifizierungsfehler steuert. Wir werden 10 logarithmisch gleichmäßig verteilte Werte zwischen 10^-10 und 1 testen.

C_s = np.logspace(-10, 0, 10)

Führe Kreuzvalidierung durch und protokolliere die Ergebnisse

Für jeden Wert von C führen wir eine 10-fache Kreuzvalidierung durch und protokollieren den Mittelwert und die Standardabweichung der Scores.

from sklearn.model_selection import cross_val_score

scores = list()
scores_std = list()
for C in C_s:
    svc.C = C
    this_scores = cross_val_score(svc, X, y, n_jobs=1)
    scores.append(np.mean(this_scores))
    scores_std.append(np.std(this_scores))

Zeichne die Ergebnisse

Schließlich zeichnen wir die Mittelwerte der Scores als Funktion von C und fügen auch Fehlerbalken hinzu, um die Standardabweichung zu visualisieren.

import matplotlib.pyplot as plt

plt.figure()
plt.semilogx(C_s, scores)
plt.semilogx(C_s, np.array(scores) + np.array(scores_std), "b--")
plt.semilogx(C_s, np.array(scores) - np.array(scores_std), "b--")
locs, labels = plt.yticks()
plt.yticks(locs, list(map(lambda x: "%g" % x, locs)))
plt.ylabel("CV score")
plt.xlabel("Parameter C")
plt.ylim(0, 1.1)
plt.show()

Zusammenfassung

In diesem Lab haben wir eine 10-fache Kreuzvalidierung mit einem SVM-Modell auf dem digits-Datensatz durchgeführt und verschiedene Werte des Regularisierungsparameters C getestet. Wir haben die Ergebnisse geplottet, um die Beziehung zwischen C und dem mittleren Kreuzvalidierungsscore zu visualisieren. Dies ist eine nützliche Technik zum Einstellen von Hyperparametern und zur Beurteilung der Modellleistung.