Matplotlib Color Vision Deficiency Simulation


This tutorial is from the open-source community.

Introduction

Matplotlib is a popular data visualization library for Python. This lab walks through an example that adds a toolbar button to Matplotlib figures for simulating color vision deficiencies. The simulation itself relies on the colorspacious package and is applied to the rendered figure through an Agg filter.

Import necessary libraries and modules

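First, we import functools, pathlib.Path, colorspacious, and NumPy. Matplotlib itself is imported later, in the demo code; functools and Path are not used in the code shown in this lab. We also define the button label, its tooltip, and the menu entries, which map each label shown to the user to the internal filter name.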

import functools
from pathlib import Path

import colorspacious

import numpy as np

_BUTTON_NAME = "Filter"
_BUTTON_HELP = "Simulate color vision deficiencies"
_MENU_ENTRIES = {
    "None": None,
    "Greyscale": "greyscale",
    "Deuteranopia": "deuteranomaly",
    "Protanopia": "protanomaly",
    "Tritanopia": "tritanomaly",
}
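
The menu labels use the "-opia" names, while the values use the "-anomaly" spelling that the filter code later passes as `cvd_type` together with a severity of 100. As a minimal sketch, the hypothetical helper `describe_entry` below (it is not part of the lab) shows what each label resolves to. It works on a local copy of the dictionary, so it needs no extra packages.

```python
# Local copy of the lab's menu mapping, repeated so the sketch is self-contained.
MENU_ENTRIES = {
    "None": None,
    "Greyscale": "greyscale",
    "Deuteranopia": "deuteranomaly",
    "Protanopia": "protanomaly",
    "Tritanopia": "tritanomaly",
}


def describe_entry(label):
    """Hypothetical helper: return what a menu label resolves to."""
    if label not in MENU_ENTRIES:
        raise ValueError(f"Unsupported filter name: {label!r}")
    internal = MENU_ENTRIES[label]
    if internal is None or internal == "greyscale":
        # No filter, or the separate greyscale branch.
        return internal
    # Same dictionary shape that the lab's filter code builds.
    return {"name": "sRGB1+CVD", "cvd_type": internal, "severity": 100}


for label in MENU_ENTRIES:
    print(f"{label:>12} -> {describe_entry(label)}")
```

Note that the lookup is case-sensitive: `describe_entry("deuteranopia")` raises `ValueError`, just as the lab's `_get_color_filter` would.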

Define the color filter function

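Next, we define `_get_color_filter`, which turns a menu label into a filter function. For "Greyscale", it converts the image to the JCh color space, sets the chroma channel to zero, and converts back to sRGB. For the three deficiency types, it builds a colorspacious "sRGB1+CVD" color space with severity 100 and converts from that space to sRGB. The returned function takes an image and a dpi value and returns the filtered image plus two zero offsets. For "None", no filter function is returned.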

def _get_color_filter(name):
    """
    Given a color filter name, create a color filter function.

    Parameters
    ----------
    name : str
        The color filter name, one of the keys of ``_MENU_ENTRIES``:

        - ``"None"``: ...
        - ``"Greyscale"``: Convert the input to luminosity.
        - ``"Deuteranopia"``: Simulate the most common form of red-green
          colorblindness.
        - ``"Protanopia"``: Simulate a rarer form of red-green colorblindness.
        - ``"Tritanopia"``: Simulate the rare form of blue-yellow
          colorblindness.

        Color conversions use `colorspacious`_.

    Returns
    -------
    callable or None
        ``None`` if no filtering is requested; otherwise a color filter
        function that has the form:

        def filter_func(im: np.ndarray[M, N, D], dpi)
            -> (np.ndarray[M, N, D], 0, 0)

        where (M, N) are the image dimensions, and D is the color depth (3 for
        RGB, 4 for RGBA). Alpha is passed through unchanged and otherwise
        ignored. The two zeros are the image offsets.
    """
    if name not in _MENU_ENTRIES:
        raise ValueError(f"Unsupported filter name: {name!r}")
    name = _MENU_ENTRIES[name]

    if name is None:
        return None

    elif name == "greyscale":
        rgb_to_jch = colorspacious.cspace_converter("sRGB1", "JCh")
        jch_to_rgb = colorspacious.cspace_converter("JCh", "sRGB1")

        def convert(im):
            greyscale_JCh = rgb_to_jch(im)
            greyscale_JCh[..., 1] = 0
            im = jch_to_rgb(greyscale_JCh)
            return im

    else:
        cvd_space = {"name": "sRGB1+CVD", "cvd_type": name, "severity": 100}
        convert = colorspacious.cspace_converter(cvd_space, "sRGB1")

    def filter_func(im, dpi):
        alpha = None
        if im.shape[-1] == 4:
            im, alpha = im[..., :3], im[..., 3]
        im = convert(im)
        if alpha is not None:
            im = np.dstack((im, alpha))
        return np.clip(im, 0, 1), 0, 0

    return filter_func
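
The alpha handling in `filter_func` can be tried on its own, without colorspacious. The sketch below repeats the wrapper from the lab under the hypothetical name `make_filter` and feeds it a stand-in converter, `swap_red_blue`, which is an assumption made for illustration and not a model of any color vision deficiency.

```python
import numpy as np


def make_filter(convert):
    """Wrap an RGB-only converter so that it also accepts RGBA input."""
    def filter_func(im, dpi):
        alpha = None
        if im.shape[-1] == 4:
            # Split off the alpha channel; only RGB is converted.
            im, alpha = im[..., :3], im[..., 3]
        im = convert(im)
        if alpha is not None:
            # Re-attach the untouched alpha channel.
            im = np.dstack((im, alpha))
        # The two zeros are offsets: draw the result where the input was.
        return np.clip(im, 0, 1), 0, 0
    return filter_func


def swap_red_blue(rgb):
    """Stand-in converter (illustration only): reverse the channel order."""
    return rgb[..., ::-1]


rgba = np.zeros((2, 3, 4))
rgba[..., 0] = 1.0  # pure red
rgba[..., 3] = 0.5  # half transparent

out, ox, oy = make_filter(swap_red_blue)(rgba, dpi=100)
print(out.shape, out[0, 0], (ox, oy))
```

The output keeps the input shape, the red channel moves to blue, and the alpha value of 0.5 comes through unchanged.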

Set menu entry

We define a helper that applies the selected filter. It installs the filter function on the toolbar's figure with `set_agg_filter` and calls `draw_idle` to schedule a redraw. Selecting "None" passes `None`, which leaves the figure unfiltered.

def _set_menu_entry(tb, name):
    tb.canvas.figure.set_agg_filter(_get_color_filter(name))
    tb.canvas.draw_idle()
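
To see `set_agg_filter` at work without a GUI, a figure can be rendered off-screen with the Agg canvas. The sketch below is a minimal illustration under stated assumptions: `grey_filter` is a hypothetical stand-in that uses a simple weighted channel sum, not the colorspacious JCh conversion used in this lab.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def grey_filter(im, dpi):
    """Hypothetical stand-in filter: weighted channel sum, alpha kept."""
    rgb, alpha = im[..., :3], im[..., 3:]
    luma = rgb @ np.array([0.2126, 0.7152, 0.0722])
    grey = np.repeat(luma[..., np.newaxis], 3, axis=-1)
    filtered = np.concatenate([grey, alpha], axis=-1)
    return np.clip(filtered, 0, 1), 0, 0


# Build a small figure that renders into an in-memory Agg buffer.
fig = Figure(figsize=(2, 2), dpi=50)
canvas = FigureCanvasAgg(fig)
ax = fig.add_subplot()
ax.plot([0, 1], [0, 1], color="red", linewidth=4)

# Install the filter on the whole figure, then render.
fig.set_agg_filter(grey_filter)
canvas.draw()

pixels = np.asarray(canvas.buffer_rgba())
channels_equal = (
    np.array_equal(pixels[..., 0], pixels[..., 1])
    and np.array_equal(pixels[..., 1], pixels[..., 2])
)
print("all pixels grey:", channels_equal)
```

Because the filter is set on the figure rather than on a single artist, it sees the whole rendered image, which is why the lab's toolbar button affects every subplot at once.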

Set up the toolbar

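Next, we define `setup`, which receives a figure and adds the filter button to its toolbar. It walks the method resolution order of the toolbar's class to find the first class that does not come from the matplotlib package; the top-level package of that class identifies the GUI toolkit (GTK, Qt, Tk, or wx). It then calls the matching helper to create the button. The helpers `_setup_gtk`, `_setup_qt`, `_setup_tk`, and `_setup_wx` are not shown in this lab and must be defined for `setup` to work. If the figure has no toolbar, `setup` returns without doing anything.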

def setup(figure):
    tb = figure.canvas.toolbar
    if tb is None:
        return
    for cls in type(tb).__mro__:
        pkg = cls.__module__.split(".")[0]
        if pkg != "matplotlib":
            break
    if pkg == "gi":
        _setup_gtk(tb)
    elif pkg in ("PyQt5", "PySide2", "PyQt6", "PySide6"):
        _setup_qt(tb)
    elif pkg == "tkinter":
        _setup_tk(tb)
    elif pkg == "wx":
        _setup_wx(tb)
    else:
        raise NotImplementedError("The current backend is not supported")
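
The toolkit detection in `setup` relies only on class metadata, so it can be illustrated without any GUI toolkit installed. In the sketch below, `first_foreign_package` is a hypothetical helper that repeats the loop from `setup`, and the classes are stand-ins with hand-set module names.

```python
def first_foreign_package(obj, own_package="matplotlib"):
    """Hypothetical helper: top-level package of the first class in the
    MRO of ``obj`` that does not belong to ``own_package``."""
    for cls in type(obj).__mro__:
        pkg = cls.__module__.split(".")[0]
        if pkg != own_package:
            break
    return pkg


# Stand-in classes; the module names are set by hand for illustration.
ToolkitWidget = type("ToolkitWidget", (), {"__module__": "tkinter.ttk"})
FakeToolbar = type(
    "FakeToolbar",
    (ToolkitWidget,),
    {"__module__": "matplotlib.backends.fake_backend"},
)
PlainToolbar = type("PlainToolbar", (), {"__module__": "matplotlib.fake"})

print(first_foreign_package(FakeToolbar()))   # tkinter
print(first_foreign_package(PlainToolbar()))  # builtins
```

The second call shows the fallback case: when no toolkit class is found, the loop stops at `object`, whose package is `builtins`, and `setup` would then raise `NotImplementedError`.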

Create sample images

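Finally, we build a 2x2 demo figure: the same test surface drawn with the viridis and turbo colormaps, a sample photo of Grace Hopper, and four sine waves of different frequencies with a legend. The line `plt.rcParams['figure.hooks'].append('mplcvd:setup')` registers the `setup` function of a module named `mplcvd` as a figure hook, so the code in this lab must be saved as `mplcvd.py` somewhere on the import path for the hook to resolve.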

if __name__ == '__main__':
    import matplotlib.pyplot as plt

    from matplotlib import cbook

    plt.rcParams['figure.hooks'].append('mplcvd:setup')

    fig, axd = plt.subplot_mosaic(
        [
            ['viridis', 'turbo'],
            ['photo', 'lines']
        ]
    )

    delta = 0.025
    x = y = np.arange(-3.0, 3.0, delta)
    X, Y = np.meshgrid(x, y)
    Z1 = np.exp(-X**2 - Y**2)
    Z2 = np.exp(-(X - 1)**2 - (Y - 1)**2)
    Z = (Z1 - Z2) * 2

    imv = axd['viridis'].imshow(
        Z, interpolation='bilinear',
        origin='lower', extent=[-3, 3, -3, 3],
        vmax=abs(Z).max(), vmin=-abs(Z).max()
    )
    fig.colorbar(imv)
    imt = axd['turbo'].imshow(
        Z, interpolation='bilinear', cmap='turbo',
        origin='lower', extent=[-3, 3, -3, 3],
        vmax=abs(Z).max(), vmin=-abs(Z).max()
    )
    fig.colorbar(imt)

    # A sample image
    with cbook.get_sample_data('grace_hopper.jpg') as image_file:
        photo = plt.imread(image_file)
    axd['photo'].imshow(photo)

    th = np.linspace(0, 2*np.pi, 1024)
    for j in [1, 2, 4, 6]:
        axd['lines'].plot(th, np.sin(th * j), label=f'$\\omega={j}$')
    axd['lines'].legend(ncols=2, loc='upper right')
    plt.show()
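
The hook string `'mplcvd:setup'` has the form `module:callable`. As a rough sketch of how such a string can be resolved, the hypothetical function `resolve_hook` below uses `importlib`; it illustrates the format only and is not Matplotlib's own implementation. The standard-library target `math:sqrt` is used so the sketch runs without `mplcvd.py` being present.

```python
import importlib


def resolve_hook(spec):
    """Hypothetical resolver for a 'dotted.module:dotted.callable' string."""
    module_name, _, attr_path = spec.partition(":")
    obj = importlib.import_module(module_name)
    # Follow the dotted attribute path inside the module.
    for part in attr_path.split("."):
        obj = getattr(obj, part)
    return obj


hook = resolve_hook("math:sqrt")
print(hook(9.0))  # 3.0
```

With the lab's code saved as `mplcvd.py`, the same pattern would resolve `'mplcvd:setup'` to the `setup` function defined earlier.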

Summary

In this lab, we learned how to simulate color vision deficiencies in Matplotlib figures. We used the colorspacious package to build greyscale and deficiency filters, installed the chosen filter on the figure with `set_agg_filter`, and registered a figure hook that adds the filter button to the toolbar. We also created a demo figure with colormaps, a photo, and line plots to compare the filters.
