Scikit - 学习多类随机梯度下降分类器

Beginner

This tutorial is from open-source community. Access the source code

简介

在本实验中,我们将使用 Scikit-Learn 的 SGDClassifier 在著名的鸢尾花数据集上实现一个多类分类模型。我们将在数据集上绘制模型的决策面,并可视化对应于三个一对多(OVA)分类器的超平面。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。

加载并准备数据

我们首先导入必要的库并加载鸢尾花数据集。然后,我们将对数据进行洗牌(打乱顺序)并标准化,以便用于训练。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier

## 加载鸢尾花数据集
iris = datasets.load_iris()

## 取前两个特征
X = iris.data[:, :2]
y = iris.target
colors = "bry"

## 打乱数据
idx = np.arange(X.shape[0])
np.random.seed(13)
np.random.shuffle(idx)
X = X[idx]
y = y[idx]

## 标准化数据
mean = X.mean(axis=0)
std = X.std(axis=0)
X = (X - mean) / std

训练模型

现在,我们将借助 fit() 方法在鸢尾花数据集上训练 SGDClassifier 模型。此方法将输入数据和目标值作为输入,并在给定数据上训练模型。

clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y)

可视化决策面

现在,我们将在鸢尾花数据集上绘制训练好的模型的决策面。我们将使用 DecisionBoundaryDisplay 类来可视化模型的决策边界。

from sklearn.inspection import DecisionBoundaryDisplay

ax = plt.gca()
DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    cmap=plt.cm.Paired,
    ax=ax,
    response_method="predict",
    xlabel=iris.feature_names[0],
    ylabel=iris.feature_names[1],
)
plt.axis("tight")

绘制训练点

现在我们将在决策面上绘制训练点。我们将使用 scatter() 方法为不同的目标值绘制不同颜色的训练点。

for i, color in zip(clf.classes_, colors):
    idx = np.where(y == i)
    plt.scatter(
        X[idx, 0],
        X[idx, 1],
        c=color,
        label=iris.target_names[i],
        cmap=plt.cm.Paired,
        edgecolor="black",
        s=20,
    )
plt.title("Decision surface of multi-class SGD")
plt.axis("tight")

绘制一对多分类器

现在我们将在决策面上绘制三个一对多(OVA)分类器。我们将使用训练好的模型的 coef_intercept_ 属性来绘制与 OVA 分类器相对应的超平面。

xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
coef = clf.coef_
intercept = clf.intercept_


def plot_hyperplane(c, color):
    def line(x0):
        return (-(x0 * coef[c, 0]) - intercept[c]) / coef[c, 1]

    plt.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color)


for i, color in zip(clf.classes_, colors):
    plot_hyperplane(i, color)
plt.legend()
plt.show()

总结

在本实验中,我们学习了如何使用 Scikit-Learn 的 SGDClassifier 在鸢尾花数据集上实现多类分类模型。我们在数据集上可视化了训练好的模型的决策面,并绘制了与三个一对多(OVA)分类器相对应的超平面。