创建混淆矩阵

PythonBeginner
立即练习

介绍

在这个项目中,你将学习如何实现混淆矩阵,这是评估分类模型性能的一个基本工具。混淆矩阵详细地展示了模型的预测结果,使你能够确定需要改进的地方,并深入了解模型的优缺点。

🎯 任务

在这个项目中,你将学习:

  • 如何实现 confusion_matrix 函数来计算分类问题的混淆矩阵
  • 如何测试和优化 confusion_matrix 函数以处理边界情况并提高其鲁棒性
  • 如何记录 confusion_matrix 函数,使其更便于用户使用和理解
  • 如何将 confusion_matrix 函数集成到更大的机器学习项目中,并使用它来评估分类模型的性能

🏆 成果

完成这个项目后,你将能够:

  • 计算并解释分类问题的混淆矩阵
  • 应用处理边界情况和提高函数鲁棒性的技术
  • 实施记录代码并使其更便于用户使用的最佳实践
  • 在更大的机器学习项目中应用混淆矩阵

实现混淆矩阵函数

在这一步中,你将在 confusion_matrix.py 文件中实现 confusion_matrix 函数。该函数将计算分类问题的混淆矩阵。

confusion_matrix 函数接受三个输入:

  1. labels:表示不同类别的标签列表。
  2. preds:预测列表,其中每个预测是与 labels 列表中的类别相对应的概率列表。
  3. ground_truth:真实标签列表。

该函数应返回混淆矩阵,以列表的列表形式表示,其中每个内部列表表示矩阵中的一行。

以下是 confusion_matrix 函数的起始代码:

def confusion_matrix(
    labels: List, preds: List[List[float]], ground_truth: List
) -> List[List[int]]:
    """
    计算分类问题的混淆矩阵。

    该函数接受一个标签列表、一个预测列表(每个预测是每个类别的概率列表)和一个真实标签列表,并返回一个混淆矩阵。
    混淆矩阵是一个方阵,其中 (i, j) 项是当真实类别为 j 时,类别 i 被预测的次数。

    参数:
    labels (List):表示不同类别的标签列表。
    preds (List[List[float]]):预测列表,其中每个预测是与 labels 列表中的类别相对应的概率列表。
    ground_truth (List):真实标签列表。

    返回:
    List[List[int]]:以列表的列表形式表示的混淆矩阵,其中每个列表表示矩阵中的一行。
    """
    ## 这创建了一个维度等于类别数量的方阵,将所有元素初始化为零。每行和每列对应一个类别标签。
    matrix = [[0 for _ in range(len(labels))] for _ in range(len(labels))]

    ## 这个循环将每个预测与其对应的真实标签配对,并逐个进行处理。
    for pred, truth in zip(preds, ground_truth):
        ## 使用 NumPy 找到预测列表中概率最高的索引,该索引对应于预测的类别。
        pred_index = np.argmax(pred)
        ## 在 `labels` 列表中找到真实类别标签的索引。
        truth_index = labels.index(truth)
        ## 这行代码增加了混淆矩阵中预测类别行和真实类别列交叉处的单元格的值,有效地统计了这个特定预测 - 真实标签对的出现次数。
        matrix[pred_index][truth_index] += 1

    ## 处理完所有预测后,函数返回计算出的混淆矩阵。
    return matrix

confusion_matrix 函数中,你实现了计算分类问题混淆矩阵的逻辑。

✨ 查看解决方案并练习

测试混淆矩阵函数

在这一步中,你将使用提供的示例来测试 confusion_matrix 函数。

confusion_matrix.py 文件中添加以下代码:

if __name__ == "__main__":
    labels = ["Python", "Java", "C++"]
    preds = [
        [0.66528198, 0.21971853, 0.11499949],
        [0.34275858, 0.05847305, 0.59876836],
        [0.47650585, 0.26353373, 0.25996042],
        [0.76153846, 0.15384615, 0.08461538],
        [0.04691943, 0.9478673, 0.00521327],
    ]
    ground_truth = ["Python", "C++", "Java", "C++", "Java"]
    matrix = confusion_matrix(labels, preds, ground_truth)
    print(matrix)

运行 confusion_matrix.py 文件以执行该示例:

python confusion_matrix.py

输出应如下所示:

[[1, 1, 1],
 [0, 1, 0],
 [0, 0, 1]]

验证输出是否与预期的混淆矩阵匹配。

如果输出与预期不符,请检查 confusion_matrix 函数的实现并进行必要的修正。

✨ 查看解决方案并练习

总结

恭喜你!你已经完成了这个项目。你可以在 LabEx 中练习更多实验来提升你的技能。