Scikit 学习混淆矩阵

Machine LearningMachine LearningBeginner
立即练习

This tutorial is from open-source community. Access the source code

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

混淆矩阵是一种用于评估分类算法性能的工具。它是一个表格,通过将预测的类别标签与实际的类别标签进行比较,总结分类模型的性能。本教程演示如何使用scikit-learn库生成混淆矩阵并可视化其结果。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问Jupyter Notebook进行练习。

有时,你可能需要等待几秒钟让Jupyter Notebook完成加载。由于Jupyter Notebook的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向Labby提问。课程结束后提供反馈,我们会及时为你解决问题。


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL sklearn(("Sklearn")) -.-> sklearn/ModelSelectionandEvaluationGroup(["Model Selection and Evaluation"]) ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/model_selection("Model Selection") sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/metrics("Metrics") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills sklearn/model_selection -.-> lab-49094{{"Scikit 学习混淆矩阵"}} sklearn/metrics -.-> lab-49094{{"Scikit 学习混淆矩阵"}} ml/sklearn -.-> lab-49094{{"Scikit 学习混淆矩阵"}} end

导入库

首先,我们需要导入必要的库。我们将使用scikit-learn、matplotlib、numpy和datasets。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay

加载数据

我们将使用scikit-learn中的鸢尾花数据集。该数据集包含150个样本,每个样本有四个特征和一个目标标签。

iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

划分数据

我们会将数据集划分为训练集和测试集。训练集用于训练模型,测试集用于评估模型的性能。

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

训练模型

我们将使用线性核来训练一个支持向量机(SVM)分类器。我们会使用一个设置得很低的正则化参数C,以便观察其对结果的影响。

classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)

生成混淆矩阵

我们将使用scikit-learn中的ConfusionMatrixDisplay类生成一个混淆矩阵。该混淆矩阵会展示每个类别的正确预测和错误预测的数量。

np.set_printoptions(precision=2)
disp = ConfusionMatrixDisplay.from_estimator(
    classifier,
    X_test,
    y_test,
    display_labels=class_names,
    cmap=plt.cm.Blues,
    normalize=None,
)

可视化混淆矩阵

我们将使用matplotlib来可视化混淆矩阵。我们会绘制一个未归一化的混淆矩阵和一个归一化的混淆矩阵。

titles_options = [
    ("Confusion matrix, without normalization", None),
    ("Normalized confusion matrix", "true"),
]
for title, normalize in titles_options:
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=class_names,
        cmap=plt.cm.Blues,
        normalize=normalize,
    )
    disp.ax_.set_title(title)
    print(title)
    print(disp.confusion_matrix)
plt.show()

总结

在本教程中,我们学习了如何使用scikit-learn库来生成混淆矩阵并可视化其结果。我们加载了鸢尾花数据集,将其拆分为训练集和测试集,训练了一个支持向量机分类器,并生成和可视化了一个混淆矩阵。混淆矩阵展示了每个类别的正确预测和错误预测的数量,而可视化则帮助我们理解结果。