使用支持向量机进行鸢尾花二元分类

Beginner

This tutorial is from open-source community. Access the source code

简介

本教程将介绍使用不同的支持向量机(SVM)内核进行分类的过程。我们将使用鸢尾花数据集,其中包含花朵的测量数据。该数据集有三个类别,但我们将仅使用其中两个类别进行二分类。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们将立即为你解决问题。

加载数据

我们将首先加载鸢尾花数据集,并仅选择前两个特征用于可视化。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets, svm

iris = datasets.load_iris()
X = iris.data
y = iris.target

X = X[y!= 0, :2]
y = y[y!= 0]

准备数据

接下来,我们将准备用于训练和测试的数据。我们会将数据分成 90% 用于训练,10% 用于测试。

n_sample = len(X)

np.random.seed(0)
order = np.random.permutation(n_sample)
X = X[order]
y = y[order].astype(float)

X_train = X[: int(0.9 * n_sample)]
y_train = y[: int(0.9 * n_sample)]
X_test = X[int(0.9 * n_sample) :]
y_test = y[int(0.9 * n_sample) :]

使用不同内核训练模型

现在我们将使用三种不同的内核训练支持向量机(SVM)模型:线性(linear)、径向基函数(rbf)和多项式(poly)。对于每个内核,我们将模型拟合到训练数据上,绘制决策边界,并展示在测试数据上的准确率。

## fit the model
for kernel in ("linear", "rbf", "poly"):
    clf = svm.SVC(kernel=kernel, gamma=10)
    clf.fit(X_train, y_train)

    plt.figure()
    plt.clf()
    plt.scatter(
        X[:, 0], X[:, 1], c=y, zorder=10, cmap=plt.cm.Paired, edgecolor="k", s=20
    )

    ## Circle out the test data
    plt.scatter(
        X_test[:, 0], X_test[:, 1], s=80, facecolors="none", zorder=10, edgecolor="k"
    )

    plt.axis("tight")
    x_min = X[:, 0].min()
    x_max = X[:, 0].max()
    y_min = X[:, 1].min()
    y_max = X[:, 1].max()

    XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
    Z = clf.decision_function(np.c_[XX.ravel(), YY.ravel()])

    ## Put the result into a color plot
    Z = Z.reshape(XX.shape)
    plt.pcolormesh(XX, YY, Z > 0, cmap=plt.cm.Paired)
    plt.contour(
        XX,
        YY,
        Z,
        colors=["k", "k", "k"],
        linestyles=["--", "-", "--"],
        levels=[-0.5, 0, 0.5],
    )

    plt.title(kernel)
    plt.show()

    print(f"Accuracy with {kernel} kernel: {clf.score(X_test, y_test)}")

解读结果

我们可以看到,线性内核产生一个线性决策边界,而径向基函数(rbf)和多项式(poly)内核产生更复杂的边界。在测试数据上,径向基函数内核的准确率最高,其次是多项式内核,然后是线性内核。

总结

在本教程中,我们学习了如何使用不同的支持向量机(SVM)内核进行分类。我们使用三种不同的内核训练了一个 SVM 模型,并可视化了每个内核的决策边界。我们还计算了每个内核在测试数据上的准确率。我们发现,对于鸢尾花数据集,径向基函数(rbf)内核产生了最佳结果。