Einführung
In diesem Lab wird die Kreuzvalidierung mit einer Support Vector Machine (SVM) auf dem Digits-Datensatz durchgeführt. Dies ist ein Klassifizierungsproblem, bei dem die Aufgabe besteht, Ziffern aus Bildern von handschriftlichen Ziffern zu identifizieren.
Tipps für die VM
Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.
Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.
Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.
Lade den Datensatz
Zunächst müssen wir den Digits-Datensatz aus scikit-learn laden und ihn in Features und Labels unterteilen.
import numpy as np
from sklearn import datasets
X, y = datasets.load_digits(return_X_y=True)
Erstelle ein Support Vector Machine (SVM)-Modell
Als nächstes erstellen wir ein SVM-Modell mit einem linearen Kernel.
from sklearn import svm
svc = svm.SVC(kernel="linear")
Definiere die Hyperparameterwerte zum Testen
Wir werden verschiedene Werte des Regularisierungsparameters C testen, der das Kompromissverhältnis zwischen der Maximalisierung der Margin und der Minimierung der Klassifizierungsfehler steuert. Wir werden 10 logarithmisch gleichmäßig verteilte Werte zwischen 10^-10 und 1 testen.
C_s = np.logspace(-10, 0, 10)
Führe Kreuzvalidierung durch und protokolliere die Ergebnisse
Für jeden Wert von C führen wir eine 10-fache Kreuzvalidierung durch und protokollieren den Mittelwert und die Standardabweichung der Scores.
from sklearn.model_selection import cross_val_score
scores = list()
scores_std = list()
for C in C_s:
svc.C = C
this_scores = cross_val_score(svc, X, y, n_jobs=1)
scores.append(np.mean(this_scores))
scores_std.append(np.std(this_scores))
Zeichne die Ergebnisse
Schließlich zeichnen wir die Mittelwerte der Scores als Funktion von C und fügen auch Fehlerbalken hinzu, um die Standardabweichung zu visualisieren.
import matplotlib.pyplot as plt
plt.figure()
plt.semilogx(C_s, scores)
plt.semilogx(C_s, np.array(scores) + np.array(scores_std), "b--")
plt.semilogx(C_s, np.array(scores) - np.array(scores_std), "b--")
locs, labels = plt.yticks()
plt.yticks(locs, list(map(lambda x: "%g" % x, locs)))
plt.ylabel("CV score")
plt.xlabel("Parameter C")
plt.ylim(0, 1.1)
plt.show()
Zusammenfassung
In diesem Lab haben wir eine 10-fache Kreuzvalidierung mit einem SVM-Modell auf dem digits-Datensatz durchgeführt und verschiedene Werte des Regularisierungsparameters C getestet. Wir haben die Ergebnisse geplottet, um die Beziehung zwischen C und dem mittleren Kreuzvalidierungsscore zu visualisieren. Dies ist eine nützliche Technik zum Einstellen von Hyperparametern und zur Beurteilung der Modellleistung.