创建带注释的热图

MatplotlibMatplotlibBeginner
立即练习

This tutorial is from open-source community. Access the source code

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

在数据分析中,我们经常希望将依赖于两个自变量的数据显示为一种称为热图的颜色编码图像图。在本实验中,我们将使用 Matplotlib 的 imshow 函数创建一个带有注释的热图。我们将从一个简单的示例开始,并将其扩展为一个通用函数。

虚拟机提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,请随时向 Labby 提问。课程结束后提供反馈,我们将立即为你解决问题。

简单分类热图

我们将从定义一些数据开始。我们需要一个二维列表或数组来定义要进行颜色编码的数据。然后,我们还需要两个类别列表或数组。热图本身是一个 imshow 图,其标签设置为类别。我们将使用 text 函数通过在每个单元格内创建一个显示该单元格值的 Text 来为数据本身添加标签。

import matplotlib.pyplot as plt
import numpy as np

vegetables = ["cucumber", "tomato", "lettuce", "asparagus", "potato", "wheat", "barley"]
farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening", "Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]

harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
                    [2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
                    [1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
                    [0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
                    [0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
                    [1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
                    [0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])

fig, ax = plt.subplots()
im = ax.imshow(harvest)

## 显示所有刻度并使用相应的列表条目为其标注标签
ax.set_xticks(np.arange(len(farmers)), labels=farmers)
ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)

## 旋转刻度标签并设置其对齐方式
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

## 遍历数据维度并创建文本注释
for i in range(len(vegetables)):
    for j in range(len(farmers)):
        text = ax.text(j, i, harvest[i, j], ha="center", va="center", color="w")

ax.set_title("Harvest of local farmers (in tons/year)")
fig.tight_layout()
plt.show()

使用辅助函数的代码风格

我们将创建一个函数,该函数以数据以及行和列标签作为输入,并允许使用用于自定义绘图的参数。我们将关闭周围的坐标轴边框,并创建一个白色线条的网格来分隔单元格。在这里,我们还希望创建一个颜色条,并将标签放置在热图上方而不是下方。注释应根据阈值使用不同的颜色,以便与像素颜色形成更好的对比度。

def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs):
    """
    从 numpy 数组和两个标签列表创建热图。

    参数
    ----------
    data
        形状为 (M, N) 的二维 numpy 数组。
    row_labels
        长度为 M 的列表或数组,包含行的标签。
    col_labels
        长度为 N 的列表或数组,包含列的标签。
    ax
        绘制热图的 `matplotlib.axes.Axes` 实例。如果未提供,则使用当前坐标轴或创建一个新的。可选。
    cbar_kw
        传递给 `matplotlib.Figure.colorbar` 的参数字典。可选。
    cbarlabel
        颜色条的标签。可选。
    **kwargs
        所有其他参数都将转发给 `imshow`。
    """

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    ## 绘制热图
    im = ax.imshow(data, **kwargs)

    ## 创建颜色条
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    ## 显示所有刻度并用相应的列表条目为其标注标签。
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    ## 让水平坐标轴标签显示在顶部。
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    ## 旋转刻度标签并设置其对齐方式。
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    ## 关闭边框并创建白色网格。
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}", textcolors=("black", "white"), threshold=None, **textkw):
    """
    一个用于注释热图的函数。

    参数
    ----------
    im
        要标注的 AxesImage。
    data
        用于注释的数据。如果为 None,则使用图像的数据。可选。
    valfmt
        热图内注释的格式。这应该要么使用字符串格式化方法,例如 "$ {x:.2f}",要么是一个 `matplotlib.ticker.Formatter`。可选。
    textcolors
        一对颜色。第一个用于低于阈值的值,第二个用于高于阈值的值。可选。
    threshold
        根据该值应用 textcolors 中的颜色的数据单位值。如果为 None(默认),则使用颜色映射的中间值作为分隔。可选。
    **kwargs
        所有其他参数都将转发给每次用于创建文本标签的 `text` 调用。
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    ## 将阈值归一化到图像的颜色范围。
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    ## 将默认对齐方式设置为居中,但允许通过 textkw 覆盖。
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    ## 如果提供的是字符串,则获取格式化器
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    ## 遍历数据并为每个“像素”创建一个 `Text`。
    ## 根据数据更改文本的颜色。
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

应用该函数

既然我们已经有了这些函数,就可以用它们来创建一个带注释的热图。我们创建一组新的数据,为 imshow 提供更多参数,在注释上使用整数格式,并提供一些颜色。我们还通过使用 matplotlib.ticker.FuncFormatter 隐藏对角线元素(这些元素都是 1)。

data = np.random.randint(2, 100, size=(7, 7))
y = [f"Book {i}" for i in range(1, 8)]
x = [f"Store {i}" for i in list("ABCDEFG")]

fig, ax = plt.subplots()
im, _ = heatmap(data, y, x, ax=ax, vmin=0, cmap="magma_r", cbarlabel="weekly sold copies")
annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20, textcolors=("red", "white"))

def func(x, pos):
    return f"{x:.2f}".replace("0.", ".").replace("1.00", "")

annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)

更复杂的热图示例

在以下内容中,我们通过在不同情况下应用之前创建的函数并使用不同的参数,展示这些函数的多功能性。

np.random.seed(19680801)

fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))

## 用不同的字体大小和颜色映射复制上述示例。

im, _ = heatmap(harvest, vegetables, farmers, ax=ax, cmap="Wistia", cbarlabel="harvest [t/year]")
annotate_heatmap(im, valfmt="{x:.1f}", size=7)

## 有时甚至数据本身就是分类的。在这里,我们使用 `matplotlib.colors.BoundaryNorm` 将数据分类,并使用此方法为绘图上色,同时还从一个类别数组中获取类别标签。

data = np.random.randn(6, 6)
y = [f"Prod. {i}" for i in range(10, 70, 10)]
x = [f"Cycle {i}" for i in range(1, 7)]

qrates = list("ABCDEFG")
norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])

im, _ = heatmap(data, y, x, ax=ax3, cmap=mpl.colormaps["PiYG"].resampled(7), norm=norm, cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt), cbarlabel="Quality Rating")
annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1, textcolors=("red", "black"))

## 我们可以很好地绘制相关矩阵。由于其取值范围在 -1 到 1 之间,我们将其用作 vmin 和 vmax。

corr_matrix = np.corrcoef(harvest)
im, _ = heatmap(corr_matrix, vegetables, vegetables, ax=ax4, cmap="PuOr", vmin=-1, vmax=1, cbarlabel="correlation coeff.")
annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)

plt.tight_layout()
plt.show()

总结

在本实验中,我们学习了如何使用 Matplotlib 的 imshow 函数在 Python 中创建带注释的热图。我们首先创建了一个简单的分类热图,然后将其扩展为一个可重复使用的函数。最后,我们使用不同的参数探索了一些更复杂的热图示例。