简单手写字符识别分类器

PythonBeginner
立即练习

介绍

在这个项目中,你将学习如何使用 scikit-learn 库提供的 DIGITS 数据集构建一个简单的手写字符识别分类器。手写字符识别是机器学习中的一个经典问题,本项目将指导你完成创建一个能够准确预测手写字符图像中所表示数字的分类器的过程。

🎯 任务

在这个项目中,你将学习:

  • 如何加载 DIGITS 数据集并将其拆分为训练集和测试集
  • 如何在训练数据上创建并训练支持向量机(SVM)分类器
  • 如何实现一个函数来对手写字符图像进行分类
  • 如何使用示例手写字符图像测试分类器

🏆 成果

完成本项目后,你将能够:

  • 加载和预处理用于机器学习任务的数据集
  • 使用 scikit-learn 创建并训练 SVM 分类器
  • 实现一个预测函数来对新样本进行分类
  • 理解使用机器学习技术进行手写字符识别的基础知识

加载数字数据集

在这一步中,你将学习如何从 scikit-learn 库中加载 DIGITS 数据集。按照以下步骤完成此步骤:

打开 handwritten_digit_classifier.py 文件,导入必要的库:

from sklearn import datasets
from sklearn.model_selection import train_test_split

使用 datasets.load_digits() 函数加载 DIGITS 数据集:

digits = datasets.load_digits()
X, y = digits.data, digits.target

X 变量包含扁平化的 8x8 像素图像,y 变量包含相应的数字标签(0 - 9)。

使用 train_test_split() 将数据集拆分为训练集和测试集:

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

这将把数据拆分为 80% 的训练集和 20% 的测试集。

创建并训练支持向量机分类器

在这一步中,你将学习如何在训练数据上创建并训练一个支持向量机(SVM)分类器。按照以下步骤完成此步骤:

handwritten_digit_classifier.py 文件中,从 sklearn.svm 模块导入 SVC 类:

from sklearn.svm import SVC

创建一个具有线性核和正则化参数为 1 的 SVM 分类器:

clf = SVC(kernel="linear", C=1)

使用 fit() 方法在训练数据上训练 SVM 分类器:

clf.fit(X_train, y_train)

这将在训练数据上训练 SVM 分类器。

实现预测函数

在这一步中,你将实现 predict(sample) 函数,用于对手写字符图像进行分类。按照以下步骤完成此步骤:

handwritten_digit_classifier.py 文件中导入 numpy 模块:

import numpy as np

定义 predict(sample) 函数:

def predict(sample):
    """
    参数:
    sample -- 手写字符图像的像素值列表

    返回:
    pred -- 作为整数的手写字符图像的预测标签
    """
    ## 重塑输入样本
    sample = np.array(sample).reshape(1, -1)

    ## 使用训练好的分类器进行预测
    pred = clf.predict(sample)

    return int(pred[0])

predict(sample) 函数中:

  • 将输入的 sample 列表转换为 NumPy 数组,并将其重塑为具有与训练数据相同格式的单个样本。
  • 使用训练好的 clf 分类器,通过 predict() 方法预测重塑后的输入样本的标签。
  • 将预测标签作为整数返回。

测试分类器

现在你可以使用一个示例手写字符图像来测试 predict(sample) 函数。以下是 handwritten_digit_classifier.py 文件中的一个示例:

sample = [
    0.0, 0.0, 6.0, 14.0, 4.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 11.0, 16.0, 10.0, 0.0, 0.0, 0.0,
    0.0, 0.0, 8.0, 14.0, 16.0, 2.0, 0.0, 0.0,
    0.0, 0.0, 1.0, 12.0, 12.0, 11.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 11.0, 3.0, 0.0, 0.0,
    0.0, 0.0, 0.0, 0.0, 5.0, 11.0, 0.0, 0.0,
    0.0, 1.0, 4.0, 4.0, 7.0, 16.0, 2.0, 0.0,
    0.0, 7.0, 16.0, 16.0, 13.0, 11.0, 1.0, 0.0
]

result = predict(sample)
print("Predicted Label:", result)

这应该输出给定手写字符图像的预测标签。

运行 handwritten_digit_classifier.py 文件以执行该示例:

python handwritten_digit_classifier.py
## 预测标签: 9

总结

恭喜你!你已经完成了这个项目。你可以在 LabEx 中练习更多实验来提升你的技能。

✨ 查看解决方案并练习✨ 查看解决方案并练习✨ 查看解决方案并练习✨ 查看解决方案并练习