Einführung
Support Vector Machines (SVMs) werden zur Klassifizierung und Regressionsanalyse verwendet. SVMs finden die besten möglichen Linien oder Hyperebenen, die die Daten in verschiedene Klassen trennen. Die Linie oder Hyperebene, die den Abstand zwischen den beiden nächsten Datenpunkten aus verschiedenen Klassen maximiert, wird als Margin bezeichnet. In diesem Lab werden wir untersuchen, wie der Parameter C den Margin in einer linearen SVM beeinflusst.
VM-Tipps
Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.
Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.
Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.
Bibliotheken importieren
Wir beginnen mit dem Import der erforderlichen Bibliotheken, einschließlich von numpy, matplotlib und scikit-learn.
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
Daten generieren
Wir generieren 40 trennbare Punkte mit der Funktion random.randn von numpy. Die ersten 20 Punkte haben einen Mittelwert von [-2, -2], und die nächsten 20 Punkte haben einen Mittelwert von [2, 2]. Anschließend weisen wir den ersten 20 Punkten ein Klassenlabel von 0 und den nächsten 20 Punkten ein Klassenlabel von 1 zu.
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
Y = [0] * 20 + [1] * 20
Modell anpassen
Wir passen das SVM-Modell mit der Klasse SVC von scikit-learn an. Wir setzen den Kernel auf linear und den Strafparameter C auf 1 für den nicht regulierten Fall und auf 0,05 für den regulierten Fall. Anschließend berechnen wir die trennende Hyperebene mithilfe der Koeffizienten und des Interzepts des Modells.
for name, penalty in (("unreg", 1), ("reg", 0.05)):
clf = svm.SVC(kernel="linear", C=penalty)
clf.fit(X, Y)
w = clf.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(-5, 5)
yy = a * xx - (clf.intercept_[0]) / w[1]
Margen berechnen
Wir berechnen die Margen für die trennende Hyperebene. Zunächst berechnen wir die Margenentfernung mithilfe der Koeffizienten des Modells. Anschließend berechnen wir die vertikale Entfernung von den Supportvektoren zur Hyperebene mithilfe der Steigung der Hyperebene. Schließlich zeichnen wir die Linie, die Punkte und die am nächsten liegenden Vektoren zur Ebene.
margin = 1 / np.sqrt(np.sum(clf.coef_**2))
yy_down = yy - np.sqrt(1 + a**2) * margin
yy_up = yy + np.sqrt(1 + a**2) * margin
plt.plot(xx, yy, "k-")
plt.plot(xx, yy_down, "k--")
plt.plot(xx, yy_up, "k--")
plt.scatter(
clf.support_vectors_[:, 0],
clf.support_vectors_[:, 1],
s=80,
facecolors="none",
zorder=10,
edgecolors="k",
cmap=plt.get_cmap("RdBu"),
)
plt.scatter(
X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.get_cmap("RdBu"), edgecolors="k"
)
plt.axis("tight")
x_min = -4.8
x_max = 4.2
y_min = -6
y_max = 6
Kontur plotten
Wir zeichnen die Kontur der Entscheidungsfunktion. Zunächst erstellen wir ein Gitternetz mit den Arrays xx und yy. Anschließend formen wir das Gitternetz in ein 2D-Array um und wenden die decision_function-Methode der SVC-Klasse an, um die vorhergesagten Werte zu erhalten. Anschließend zeichnen wir die Kontur mit der contourf-Methode.
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)
plt.contourf(XX, YY, Z, cmap=plt.get_cmap("RdBu"), alpha=0.5, linestyles=["-"])
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.xticks(())
plt.yticks(())
Die Plots anzeigen
Wir zeigen die Plots für die nicht regulierten und regulierten Fälle an.
plt.show()
Zusammenfassung
In diesem Lab haben wir untersucht, wie der Parameter C die Margen in einem linearen SVM beeinflusst. Wir haben Daten generiert, das Modell angepasst, die Margen berechnet und die Ergebnisse geplottet. Anschließend haben wir die Plots für die nicht regulierten und regulierten Fälle angezeigt.