Scikit 学习混淆矩阵

Beginner

This tutorial is from open-source community. Access the source code

简介

混淆矩阵是一种用于评估分类算法性能的工具。它是一个表格,通过将预测的类别标签与实际的类别标签进行比较,总结分类模型的性能。本教程演示如何使用 scikit-learn 库生成混淆矩阵并可视化其结果。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。

导入库

首先,我们需要导入必要的库。我们将使用 scikit-learn、matplotlib、numpy 和 datasets。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay

加载数据

我们将使用 scikit-learn 中的鸢尾花数据集。该数据集包含 150 个样本,每个样本有四个特征和一个目标标签。

iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

划分数据

我们会将数据集划分为训练集和测试集。训练集用于训练模型,测试集用于评估模型的性能。

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

训练模型

我们将使用线性核来训练一个支持向量机(SVM)分类器。我们会使用一个设置得很低的正则化参数 C,以便观察其对结果的影响。

classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)

生成混淆矩阵

我们将使用 scikit-learn 中的 ConfusionMatrixDisplay 类生成一个混淆矩阵。该混淆矩阵会展示每个类别的正确预测和错误预测的数量。

np.set_printoptions(precision=2)
disp = ConfusionMatrixDisplay.from_estimator(
    classifier,
    X_test,
    y_test,
    display_labels=class_names,
    cmap=plt.cm.Blues,
    normalize=None,
)

可视化混淆矩阵

我们将使用 matplotlib 来可视化混淆矩阵。我们会绘制一个未归一化的混淆矩阵和一个归一化的混淆矩阵。

titles_options = [
    ("Confusion matrix, without normalization", None),
    ("Normalized confusion matrix", "true"),
]
for title, normalize in titles_options:
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=class_names,
        cmap=plt.cm.Blues,
        normalize=normalize,
    )
    disp.ax_.set_title(title)
    print(title)
    print(disp.confusion_matrix)
plt.show()

总结

在本教程中,我们学习了如何使用 scikit-learn 库来生成混淆矩阵并可视化其结果。我们加载了鸢尾花数据集,将其拆分为训练集和测试集,训练了一个支持向量机分类器,并生成和可视化了一个混淆矩阵。混淆矩阵展示了每个类别的正确预测和错误预测的数量,而可视化则帮助我们理解结果。