简介
支持向量机(Support Vector Machines,SVM)用于分类和回归分析。SVM 会找到尽可能最佳的直线或超平面,将数据分隔为不同的类别。使来自不同类别的两个最接近数据点之间的距离最大化的直线或超平面称为间隔(margin)。在本实验中,我们将探讨参数 C 如何影响线性 SVM 中的间隔。
虚拟机使用提示
虚拟机启动完成后,点击左上角切换到“笔记本”(Notebook)标签页,以访问 Jupyter Notebook 进行练习。
有时,你可能需要等待几秒钟,以便 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。
如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会立即为你解决问题。
导入库
我们首先导入必要的库,包括 numpy、matplotlib 和 scikit-learn。
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
生成数据
我们使用 numpy 的 random.randn 函数生成 40 个可分离的点。前 20 个点的均值为 [-2, -2],接下来 20 个点的均值为 [2, 2]。然后我们为前 20 个点分配类别标签 0,为接下来 20 个点分配类别标签 1。
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2, 2], np.random.randn(20, 2) + [2, 2]]
Y = [0] * 20 + [1] * 20
拟合模型
我们使用 scikit-learn 的 SVC 类来拟合 SVM 模型。对于无正则化的情况,我们将核函数设置为线性,惩罚参数 C 设置为 1;对于正则化的情况,惩罚参数 C 设置为 0.05。然后,我们使用模型的系数和截距来计算分离超平面。
for name, penalty in (("unreg", 1), ("reg", 0.05)):
clf = svm.SVC(kernel="linear", C=penalty)
clf.fit(X, Y)
w = clf.coef_[0]
a = -w[0] / w[1]
xx = np.linspace(-5, 5)
yy = a * xx - (clf.intercept_[0]) / w[1]
计算间隔
我们计算分离超平面的间隔。首先,我们使用模型的系数来计算间隔距离。然后,我们利用超平面的斜率计算从支持向量到超平面的垂直距离。最后,我们绘制直线、数据点以及离平面最近的向量。
margin = 1 / np.sqrt(np.sum(clf.coef_**2))
yy_down = yy - np.sqrt(1 + a**2) * margin
yy_up = yy + np.sqrt(1 + a**2) * margin
plt.plot(xx, yy, "k-")
plt.plot(xx, yy_down, "k--")
plt.plot(xx, yy_up, "k--")
plt.scatter(
clf.support_vectors_[:, 0],
clf.support_vectors_[:, 1],
s=80,
facecolors="none",
zorder=10,
edgecolors="k",
cmap=plt.get_cmap("RdBu"),
)
plt.scatter(
X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.get_cmap("RdBu"), edgecolors="k"
)
plt.axis("tight")
x_min = -4.8
x_max = 4.2
y_min = -6
y_max = 6
绘制等高线
我们绘制决策函数的等高线。首先,我们使用 xx 和 yy 数组创建一个网格。然后,我们将网格重塑为二维数组,并应用 SVC 类的 decision_function 方法来获取预测值。接着,我们使用 contourf 方法绘制等高线。
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)
plt.contourf(XX, YY, Z, cmap=plt.get_cmap("RdBu"), alpha=0.5, linestyles=["-"])
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.xticks(())
plt.yticks(())
显示图表
我们展示无正则化和正则化情况下的图表。
plt.show()
总结
在本实验中,我们探究了参数 C 如何影响线性支持向量机(SVM)中的间隔。我们生成了数据,拟合了模型,计算了间隔,并绘制了结果。然后,我们展示了无正则化和正则化情况下的图表。