Create Confusion Matrix

Introduction

In this project, you will learn how to implement a confusion matrix, a fundamental tool for evaluating the performance of a classification model. A confusion matrix counts how often each true class was predicted as each class. It shows how many predictions were wrong and also which classes the model confuses with one another, which points to the model's strengths, its weaknesses, and where to improve it.

🎯 Tasks

In this project, you will learn:

  • How to implement the confusion_matrix function to compute the confusion matrix for a classification problem
  • How to test and refine the confusion_matrix function to handle edge cases and improve its robustness
  • How to document the confusion_matrix function to make it more user-friendly and easier to understand
  • How to integrate the confusion_matrix function into a larger machine learning project and use it to evaluate the performance of a classification model

๐Ÿ† Achievements

After completing this project, you will be able to:

  • Compute and interpret the confusion matrix for a classification problem
  • Apply techniques for handling edge cases and improving the robustness of a function
  • Implement best practices for documenting and making code more user-friendly
  • Apply the confusion matrix in the context of a larger machine learning project

Implement the Confusion Matrix Function

In this step, you will implement the confusion_matrix function in the confusion_matrix.py file. This function will compute the confusion matrix for a classification problem.

The confusion_matrix function takes three inputs:

  1. labels: A list of labels representing the different classes.
  2. preds: A list of predictions, where each prediction is a list of probabilities for the classes, in the same order as the labels list.
  3. ground_truth: A list of ground truth labels, one per prediction. Each label must appear in labels.

The function should return the confusion matrix as a list of lists, where each inner list is one row of the matrix. Rows are indexed by predicted class and columns by true class, both in the order given by labels.
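
The three inputs must agree with each other before the matrix means anything. For example, zip() silently drops samples when preds and ground_truth differ in length. The sketch below shows one way to reject such inputs early. The helper name check_inputs and its error messages are hypothetical; the helper is not part of the lab's starter code.

```python
from typing import Sequence


def check_inputs(
    labels: Sequence, preds: Sequence[Sequence[float]], ground_truth: Sequence
) -> None:
    """Raise ValueError if the inputs cannot produce a meaningful confusion matrix.

    Hypothetical helper; it is not part of the lab's starter code.
    """
    if len(labels) == 0:
        raise ValueError("labels must contain at least one class")
    if len(set(labels)) != len(labels):
        # Duplicate labels would make labels.index() always return the first copy.
        raise ValueError("labels must be unique")
    if len(preds) != len(ground_truth):
        # zip() stops at the shorter input, so a mismatch would silently drop samples.
        raise ValueError(
            f"got {len(preds)} predictions but {len(ground_truth)} ground-truth labels"
        )
    for i, pred in enumerate(preds):
        # Each probability vector needs exactly one score per class in `labels`.
        if len(pred) != len(labels):
            raise ValueError(
                f"prediction {i} has {len(pred)} scores, expected {len(labels)}"
            )
    unknown = [truth for truth in ground_truth if truth not in labels]
    if unknown:
        raise ValueError(f"ground-truth labels not found in labels: {unknown}")


# Usage: a valid call returns None; an invalid one raises ValueError.
check_inputs(["Python", "Java"], [[0.9, 0.1]], ["Java"])
try:
    check_inputs(["Python", "Java"], [[0.9, 0.1]], ["Rust"])
except ValueError as err:
    print(err)  # ground-truth labels not found in labels: ['Rust']
```

You could call such a check at the top of confusion_matrix. Whether to raise, warn, or skip bad samples is a design choice the lab leaves open.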

Here's a complete implementation of the confusion_matrix function:

from typing import List

import numpy as np


def confusion_matrix(
    labels: List, preds: List[List[float]], ground_truth: List
) -> List[List[int]]:
    """
    Compute the confusion matrix for a classification problem.

    The function takes a list of labels, a list of predictions (each as a list of probabilities
    for each class), and a list of ground truth labels, and returns a confusion matrix.
    The confusion matrix is a square matrix where entry (i, j) is the number of times class i
    was predicted when the true class was j.

    Parameters:
    labels (List): A list of labels representing the different classes.
    preds (List[List[float]]): A list of predictions where each prediction is a list of
                               probabilities corresponding to the classes in the labels list.
    ground_truth (List): A list of ground truth labels.

    Returns:
    List[List[int]]: The confusion matrix represented as a list of lists where each list
                     represents a row in the matrix.
    """
    # Square matrix of zeros with one row and one column per class label.
    matrix = [[0 for _ in range(len(labels))] for _ in range(len(labels))]

    # Pair each prediction with its ground truth label and process them one by one.
    for pred, truth in zip(preds, ground_truth):
        # The predicted class is the index of the highest probability.
        pred_index = np.argmax(pred)
        # The true class is the position of the ground truth label in `labels`.
        truth_index = labels.index(truth)
        # Count this pair: row = predicted class, column = true class.
        matrix[pred_index][truth_index] += 1

    # After all predictions are processed, return the counts.
    return matrix

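The function first creates a square matrix of zeros with one row and one column per label. For each sample, np.argmax picks the class with the highest probability as the predicted class, and labels.index finds the position of the true class. The function then adds one to the cell at that row and column.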
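
If you would rather avoid the NumPy dependency, the same counting rule can be written with the standard library alone. The sketch below is an alternative, not the lab's reference solution, and the names confusion_matrix_pure and transpose are hypothetical. It also shows how to flip the orientation. This lab puts predicted classes in rows. scikit-learn's sklearn.metrics.confusion_matrix puts true classes in rows, so transposing converts between the two layouts.

```python
from typing import List, Sequence


def confusion_matrix_pure(
    labels: Sequence, preds: Sequence[Sequence[float]], ground_truth: Sequence
) -> List[List[int]]:
    """Same counting rule as confusion_matrix (rows = predicted, columns = true),
    using only the standard library. Assumes len(pred) == len(labels) for every pred."""
    n = len(labels)
    # Build the label -> index lookup once instead of calling labels.index() per sample.
    # Note: an unknown ground-truth label raises KeyError here, not ValueError.
    index_of = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * n for _ in range(n)]
    for pred, truth in zip(preds, ground_truth):
        # max() returns the first maximal index on ties, matching np.argmax.
        pred_index = max(range(n), key=lambda i: pred[i])
        matrix[pred_index][index_of[truth]] += 1
    return matrix


def transpose(matrix: List[List[int]]) -> List[List[int]]:
    """Swap rows and columns, e.g. to get the rows = true, columns = predicted layout."""
    return [list(column) for column in zip(*matrix)]


labels = ["Python", "Java", "C++"]
preds = [
    [0.66528198, 0.21971853, 0.11499949],
    [0.34275858, 0.05847305, 0.59876836],
    [0.47650585, 0.26353373, 0.25996042],
    [0.76153846, 0.15384615, 0.08461538],
    [0.04691943, 0.9478673, 0.00521327],
]
ground_truth = ["Python", "C++", "Java", "C++", "Java"]
m = confusion_matrix_pure(labels, preds, ground_truth)
print(m)             # [[1, 1, 1], [0, 1, 0], [0, 0, 1]]
print(transpose(m))  # [[1, 0, 0], [1, 1, 0], [1, 0, 1]]
```

The dictionary lookup replaces a linear labels.index search per sample. This matters only when there are many classes.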

Test the Confusion Matrix Function

In this step, you will test the confusion_matrix function using the provided example.

Add the following code at the end of the confusion_matrix.py file, after the function definition:

if __name__ == "__main__":
    labels = ["Python", "Java", "C++"]
    preds = [
        [0.66528198, 0.21971853, 0.11499949],
        [0.34275858, 0.05847305, 0.59876836],
        [0.47650585, 0.26353373, 0.25996042],
        [0.76153846, 0.15384615, 0.08461538],
        [0.04691943, 0.9478673, 0.00521327],
    ]
    ground_truth = ["Python", "C++", "Java", "C++", "Java"]
    matrix = confusion_matrix(labels, preds, ground_truth)
    print(matrix)

Run the confusion_matrix.py file to execute the example:

python confusion_matrix.py

The output should be:

[[1, 1, 1], [0, 1, 0], [0, 0, 1]]
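
To read this result: the diagonal counts correct predictions (3 of 5 samples). Each off-diagonal entry in a row counts samples wrongly predicted as that row's class. For example, the top row shows that one Java sample and one C++ sample were both predicted as Python. The sketch below turns the matrix into accuracy and per-class precision (row-wise) and recall (column-wise). The helper names are hypothetical, and so is the choice to score a class that is never predicted, or never present, as 0.0.

```python
from typing import Dict, List, Sequence


def accuracy(matrix: List[List[int]]) -> float:
    """Fraction of samples on the diagonal (predicted class == true class)."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


def per_class_scores(
    matrix: List[List[int]], labels: Sequence[str]
) -> Dict[str, Dict[str, float]]:
    """Precision and recall per class, for rows = predicted and columns = true.

    Assumption: a class that is never predicted (or never present) scores 0.0
    instead of raising ZeroDivisionError.
    """
    scores = {}
    for k, label in enumerate(labels):
        hits = matrix[k][k]
        predicted_as_k = sum(matrix[k])             # row sum: times k was predicted
        actually_k = sum(row[k] for row in matrix)  # column sum: true samples of k
        scores[label] = {
            "precision": hits / predicted_as_k if predicted_as_k else 0.0,
            "recall": hits / actually_k if actually_k else 0.0,
        }
    return scores


matrix = [[1, 1, 1], [0, 1, 0], [0, 0, 1]]
labels = ["Python", "Java", "C++"]
print(accuracy(matrix))  # 0.6
for label, s in per_class_scores(matrix, labels).items():
    print(f"{label}: precision={s['precision']:.2f}, recall={s['recall']:.2f}")
```

For this example, Python has precision 0.33 and recall 1.00: the model never misses a Python sample but overpredicts that class. Java and C++ each have precision 1.00 and recall 0.50.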

Verify that the output matches the expected confusion matrix.

If the output is not as expected, review the implementation of the confusion_matrix function and make any necessary corrections.
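
Comparing printed output by eye works for one example. A few pytest-style test functions make the check repeatable and cover edge cases such as empty input or an unknown label. The sketch below is hypothetical. To run without NumPy, it defines a stand-in with the same counting rule; in your project you would replace it with `from confusion_matrix import confusion_matrix`. You might save it as test_confusion_matrix.py (a hypothetical file name) and run it with pytest or with `python test_confusion_matrix.py`.

```python
from typing import List


def confusion_matrix(
    labels: List, preds: List[List[float]], ground_truth: List
) -> List[List[int]]:
    # Stand-in with the lab's counting rule (rows = predicted class, columns = true
    # class), written without NumPy so this file runs on its own. In your project,
    # replace it with: from confusion_matrix import confusion_matrix
    n = len(labels)
    matrix = [[0] * n for _ in range(n)]
    for pred, truth in zip(preds, ground_truth):
        pred_index = max(range(n), key=lambda i: pred[i])  # first max, like np.argmax
        matrix[pred_index][labels.index(truth)] += 1
    return matrix


LABELS = ["Python", "Java", "C++"]
PREDS = [
    [0.66528198, 0.21971853, 0.11499949],
    [0.34275858, 0.05847305, 0.59876836],
    [0.47650585, 0.26353373, 0.25996042],
    [0.76153846, 0.15384615, 0.08461538],
    [0.04691943, 0.9478673, 0.00521327],
]
GROUND_TRUTH = ["Python", "C++", "Java", "C++", "Java"]


def test_lab_example():
    assert confusion_matrix(LABELS, PREDS, GROUND_TRUTH) == [
        [1, 1, 1],
        [0, 1, 0],
        [0, 0, 1],
    ]


def test_cells_sum_to_number_of_samples():
    # Every sample lands in exactly one cell.
    matrix = confusion_matrix(LABELS, PREDS, GROUND_TRUTH)
    assert sum(sum(row) for row in matrix) == len(GROUND_TRUTH)


def test_empty_input_gives_zero_matrix():
    assert confusion_matrix(LABELS, [], []) == [[0, 0, 0], [0, 0, 0], [0, 0, 0]]


def test_unknown_label_raises_value_error():
    # list.index() raises ValueError for a label that is not in LABELS.
    try:
        confusion_matrix(LABELS, [[0.2, 0.5, 0.3]], ["Rust"])
    except ValueError:
        return
    raise AssertionError("expected ValueError for a label not in LABELS")


if __name__ == "__main__":
    for test in (
        test_lab_example,
        test_cells_sum_to_number_of_samples,
        test_empty_input_gives_zero_matrix,
        test_unknown_label_raises_value_error,
    ):
        test()
    print("all tests passed")
```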

Summary

Congratulations! You have completed this project. You can practice more labs in LabEx to improve your skills.
