MNIST 多项式逻辑回归

Machine LearningMachine LearningBeginner
立即练习

This tutorial is from open-source community. Access the source code

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

在本实验中,我们将学习如何使用逻辑回归算法对 MNIST 数据集中的手写数字进行分类。我们将使用 SAGA 算法在 MNIST 数字分类任务的一个子集上拟合带有 L1 惩罚的多项逻辑回归。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) sklearn(("Sklearn")) -.-> sklearn/CoreModelsandAlgorithmsGroup(["Core Models and Algorithms"]) sklearn(("Sklearn")) -.-> sklearn/DataPreprocessingandFeatureEngineeringGroup(["Data Preprocessing and Feature Engineering"]) sklearn(("Sklearn")) -.-> sklearn/ModelSelectionandEvaluationGroup(["Model Selection and Evaluation"]) sklearn(("Sklearn")) -.-> sklearn/UtilitiesandDatasetsGroup(["Utilities and Datasets"]) sklearn/CoreModelsandAlgorithmsGroup -.-> sklearn/linear_model("Linear Models") sklearn/DataPreprocessingandFeatureEngineeringGroup -.-> sklearn/preprocessing("Preprocessing and Normalization") sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/model_selection("Model Selection") sklearn/UtilitiesandDatasetsGroup -.-> sklearn/utils("Utilities") sklearn/UtilitiesandDatasetsGroup -.-> sklearn/datasets("Datasets") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills sklearn/linear_model -.-> lab-49297{{"MNIST 多项式逻辑回归"}} sklearn/preprocessing -.-> lab-49297{{"MNIST 多项式逻辑回归"}} sklearn/model_selection -.-> lab-49297{{"MNIST 多项式逻辑回归"}} sklearn/utils -.-> lab-49297{{"MNIST 多项式逻辑回归"}} sklearn/datasets -.-> lab-49297{{"MNIST 多项式逻辑回归"}} ml/sklearn -.-> lab-49297{{"MNIST 多项式逻辑回归"}} end

导入库

我们将首先为本实验导入必要的库。我们将使用 scikit-learn 库来获取数据集、训练模型并评估模型的性能。

import time
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state

加载 MNIST 数据集

我们将使用 scikit-learn 中的fetch_openml函数来加载 MNIST 数据集。我们还将通过将train_samples的数量设置为 5000 来选择数据的一个子集。

## Turn down for faster convergence
t0 = time.time()
train_samples = 5000

## Load data from https://www.openml.org/d/554
X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)

数据预处理

我们将通过打乱数据、将数据集拆分为训练集和测试集,以及使用StandardScaler对数据进行缩放来预处理数据。

random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=train_samples, test_size=10000
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

训练模型

我们将使用带有 L1 惩罚的逻辑回归和 SAGA 算法来训练模型。我们将把C的值设置为 50.0 除以训练样本的数量。

## Turn up tolerance for faster convergence
clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1)
clf.fit(X_train, y_train)

评估模型

我们将通过计算稀疏性和准确率得分来评估模型的性能。

sparsity = np.mean(clf.coef_ == 0) * 100
score = clf.score(X_test, y_test)

print("Sparsity with L1 penalty: %.2f%%" % sparsity)
print("Test score with L1 penalty: %.4f" % score)

可视化模型

我们将通过绘制每个类别的分类向量来可视化模型。

coef = clf.coef_.copy()
plt.figure(figsize=(10, 5))
scale = np.abs(coef).max()
for i in range(10):
    l1_plot = plt.subplot(2, 5, i + 1)
    l1_plot.imshow(
        coef[i].reshape(28, 28),
        interpolation="nearest",
        cmap=plt.cm.RdBu,
        vmin=-scale,
        vmax=scale,
    )
    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
    l1_plot.set_xlabel("Class %i" % i)
plt.suptitle("Classification vector for...")

run_time = time.time() - t0
print("Example run in %.3f s" % run_time)
plt.show()

总结

在本实验中,我们学习了如何使用逻辑回归对 MNIST 数据集中的手写数字进行分类。我们还学习了如何将带有 L1 惩罚的 SAGA 算法用于逻辑回归。我们通过稀疏权重向量实现了超过 0.8 的准确率得分,这使得模型更易于解释。然而,我们也注意到,在这个数据集上,这个准确率明显低于 L2 惩罚线性模型或非线性多层感知器模型所能达到的准确率。