使用多层感知器分类器对手写数字进行分类

Beginner

This tutorial is from open-source community. Access the source code

简介

本教程将向你展示如何使用 Scikit-learn 创建一个 MLPClassifier,以便对 MNIST 数据集中的手写数字进行分类。我们还将可视化 MLP 第一层的权重,以深入了解学习行为。

虚拟机提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,请随时向 Labby 提问。课程结束后提供反馈,我们将立即为你解决问题。

导入库

我们将首先导入此项目所需的库。

import warnings
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split

加载数据

接下来,我们将使用 Scikit-learn 的 fetch_openml 函数加载 MNIST 数据集。

X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)

数据预处理

我们将通过将每个像素值除以 255.0(这是最大像素值)来对数据进行归一化处理。

X = X / 255.0

划分数据

我们将使用 train_test_split 函数把数据集划分为训练集和测试集。

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, test_size=0.7)

训练多层感知器分类器(MLPClassifier)

我们将创建一个具有单个隐藏层且包含 40 个神经元的多层感知器分类器。由于资源限制,我们将仅训练该多层感知器 8 次迭代。我们还将捕获由于模型在有限的迭代次数内无法收敛而抛出的 ConvergenceWarning(收敛警告)。

mlp = MLPClassifier(
    hidden_layer_sizes=(40,),
    max_iter=8,
    alpha=1e-4,
    solver="sgd",
    verbose=10,
    random_state=1,
    learning_rate_init=0.2,
)

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn")
    mlp.fit(X_train, y_train)

评估模型

我们将通过计算多层感知器分类器(MLPClassifier)在训练集和测试集上的准确率来评估它。

print("Training set score: %f" % mlp.score(X_train, y_train))
print("Test set score: %f" % mlp.score(X_test, y_test))

可视化权重

最后,我们将可视化多层感知器(MLP)第一层的权重。我们将创建一个 4x4 的子图网格,并将每个权重显示为一个 28x28 像素的灰度图像。

fig, axes = plt.subplots(4, 4)
vmin, vmax = mlp.coefs_[0].min(), mlp.coefs_[0].max()
for coef, ax in zip(mlp.coefs_[0].T, axes.ravel()):
    ax.matshow(coef.reshape(28, 28), cmap=plt.cm.gray, vmin=0.5 * vmin, vmax=0.5 * vmax)
    ax.set_xticks(())
    ax.set_yticks(())

plt.show()

总结

在本教程中,我们学习了如何使用 Scikit-learn 创建一个多层感知器分类器(MLPClassifier)来对 MNIST 数据集中的手写数字进行分类。我们还可视化了多层感知器第一层的权重,以深入了解其学习行为。