Erstellen annotierter Heatmaps

Beginner

This tutorial is from open-source community. Access the source code

Einführung

Beim Datenanalyse möchten wir oft Daten, die von zwei unabhängigen Variablen abhängen, als farbkodiertes Bilddiagramm, das als Heatmap bekannt ist, darstellen. In diesem Lab verwenden wir die imshow-Funktion von Matplotlib, um eine Heatmap mit Anmerkungen zu erstellen. Wir beginnen mit einem einfachen Beispiel und erweitern es, sodass es als universelle Funktion verwendbar ist.

VM-Tipps

Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu öffnen.

Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.

Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.

Einfache kategorische Heatmap

Wir beginnen mit der Definition einiger Daten. Wir benötigen eine 2D-Liste oder -Array, das die zu farbkodierenden Daten definiert. Anschließend benötigen wir auch zwei Listen oder Arrays von Kategorien. Die Heatmap selbst ist ein imshow-Diagramm mit den Labels auf die Kategorien gesetzt. Wir werden die text-Funktion verwenden, um die Daten selbst zu beschriften, indem wir in jeder Zelle ein Text erstellen, das den Wert dieser Zelle zeigt.

import matplotlib.pyplot as plt
import numpy as np

gemüsesorten = ["Gurke", "Tomate", "Salat", "Spargel", "Kartoffel", "Weizen", "Gerste"]
bauer = ["Bauer Joe", "Upland Bros.", "Smith Gardening", "Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]

ernte = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
                  [2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
                  [1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
                  [0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
                  [0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
                  [1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
                  [0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])

fig, ax = plt.subplots()
im = ax.imshow(ernte)

## Zeige alle Striche und beschrifte sie mit den jeweiligen Listeinträgen
ax.set_xticks(np.arange(len(bauer)), labels=bauer)
ax.set_yticks(np.arange(len(gemüsesorten)), labels=gemüsesorten)

## Drehe die Strichmarkenbeschriftungen und setze ihre Ausrichtung
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

## Schleife über die Daten Dimensionen und erstelle Textanmerkungen
for i in range(len(gemüsesorten)):
    for j in range(len(bauer)):
        text = ax.text(j, i, ernte[i, j], ha="center", va="center", color="w")

ax.set_title("Ernte der örtlichen Bauern (in Tonnen/ Jahr)")
fig.tight_layout()
plt.show()

Verwendung des Hilfsfunktions-Code-Stils

Wir werden eine Funktion erstellen, die die Daten sowie die Zeilen- und Spaltenbeschriftungen als Eingabe nimmt und Argumente zulässt, die zum Anpassen des Diagramms verwendet werden. Wir werden die umgebenden Achsenkanten deaktivieren und ein Gitter aus weißen Linien erstellen, um die Zellen voneinander zu trennen. Hier möchten wir auch eine Farbskala erstellen und die Beschriftungen über der Heatmap statt darunter positionieren. Die Anmerkungen sollen je nach Schwellenwert unterschiedliche Farben erhalten, um einen besseren Kontrast zu der Pixelfarbe zu erzielen.

def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs):
    """
    Erstellt eine Heatmap aus einem numpy-Array und zwei Listen von Beschriftungen.

    Parameter
    ----------
    data
        Ein 2D numpy-Array der Größe (M, N).
    row_labels
        Eine Liste oder ein Array der Länge M mit den Beschriftungen für die Zeilen.
    col_labels
        Eine Liste oder ein Array der Länge N mit den Beschriftungen für die Spalten.
    ax
        Eine `matplotlib.axes.Axes`-Instanz, auf die die Heatmap geplottet wird. Wenn nicht angegeben, wird die aktuelle Achse verwendet oder eine neue erstellt. Optional.
    cbar_kw
        Ein Wörterbuch mit Argumenten für `matplotlib.Figure.colorbar`. Optional.
    cbarlabel
        Die Bezeichnung für die Farbskala. Optional.
    **kwargs
        Alle anderen Argumente werden an `imshow` weitergeleitet.
    """

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    ## Plotte die Heatmap
    im = ax.imshow(data, **kwargs)

    ## Erstelle die Farbskala
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    ## Zeige alle Striche und beschrifte sie mit den jeweiligen Listeinträgen.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    ## Lasse die horizontale Achsenbeschriftung oben erscheinen.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    ## Drehe die Strichmarkenbeschriftungen und setze ihre Ausrichtung.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

    ## Deaktiviere die Kanten und erstelle ein weißes Gitter.
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}", textcolors=("black", "white"), threshold=None, **textkw):
    """
    Eine Funktion, um eine Heatmap zu annotieren.

    Parameter
    ----------
    im
        Das AxesImage, das beschriftet werden soll.
    data
        Die Daten, die zur Annotation verwendet werden. Wenn None, werden die Daten des Bilds verwendet. Optional.
    valfmt
        Das Format der Anmerkungen innerhalb der Heatmap. Dies sollte entweder die String-Format-Methode verwenden, z.B. "$ {x:.2f}", oder ein `matplotlib.ticker.Formatter` sein. Optional.
    textcolors
        Ein Paar von Farben. Die erste wird für Werte unterhalb einer Schwelle verwendet, die zweite für die oberhalb. Optional.
    threshold
        Wert in Daten-Einheiten, gemäß dem die Farben aus textcolors angewendet werden. Wenn None (Standard), wird die Mitte der Farbskala als Trennung verwendet. Optional.
    **kwargs
        Alle anderen Argumente werden an jeden Aufruf von `text` weitergeleitet, der verwendet wird, um die Textbeschriftungen zu erstellen.
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    ## Normalisiere die Schwelle auf den Farbbereich des Bilds.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    ## Setze die Standardausrichtung auf zentriert, aber gestatten es, durch textkw überschrieben zu werden.
    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    ## Hole den Formatter, falls ein String übergeben wird
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    ## Schleife über die Daten und erstelle für jedes "Pixel" ein `Text`.
    ## Ändere die Farbe des Texts je nach Daten.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

Anwenden der Funktion

Jetzt, da wir die Funktionen haben, können wir sie verwenden, um eine Heatmap mit Anmerkungen zu erstellen. Wir erstellen einen neuen Datensatz, geben weitere Argumente an imshow, verwenden ein ganzzahliges Format für die Anmerkungen und geben einige Farben an. Wir verstecken auch die Diagonalelemente (die alle 1 sind), indem wir einen matplotlib.ticker.FuncFormatter verwenden.

data = np.random.randint(2, 100, size=(7, 7))
y = [f"Book {i}" for i in range(1, 8)]
x = [f"Store {i}" for i in list("ABCDEFG")]

fig, ax = plt.subplots()
im, _ = heatmap(data, y, x, ax=ax, vmin=0, cmap="magma_r", cbarlabel="weekly sold copies")
annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20, textcolors=("red", "white"))

def func(x, pos):
    return f"{x:.2f}".replace("0.", ".").replace("1.00", "")

annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)

Komplexere Heatmap-Beispiele

Im Folgenden zeigen wir die Vielseitigkeit der zuvor erstellten Funktionen, indem wir sie in verschiedenen Fällen anwenden und verschiedene Argumente verwenden.

np.random.seed(19680801)

fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))

## Wiederhole das obige Beispiel mit einer anderen Schriftgröße und Farbpalette.

im, _ = heatmap(ernte, gemüsesorten, bauer, ax=ax, cmap="Wistia", cbarlabel="Ernte [t/Jahr]")
annotate_heatmap(im, valfmt="{x:.1f}", size=7)

## Manchmal ist sogar die Daten selbst kategorisch. Hier verwenden wir eine `matplotlib.colors.BoundaryNorm`, um die Daten in Klassen zu bringen und diese zur Färbung des Diagramms zu verwenden, aber auch, um die Klassenbezeichnungen aus einem Array von Klassen zu erhalten.

data = np.random.randn(6, 6)
y = [f"Prod. {i}" for i in range(10, 70, 10)]
x = [f"Cycle {i}" for i in range(1, 7)]

qrates = list("ABCDEFG")
norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])

im, _ = heatmap(data, y, x, ax=ax3, cmap=mpl.colormaps["PiYG"].resampled(7), norm=norm, cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt), cbarlabel="Quality Rating")
annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1, textcolors=("red", "black"))

## Wir können eine Korrelationsmatrix schön plotten. Da diese zwischen -1 und 1 begrenzt ist, verwenden wir diese als vmin und vmax.

corr_matrix = np.corrcoef(ernte)
im, _ = heatmap(corr_matrix, gemüsesorten, gemüsesorten, ax=ax4, cmap="PuOr", vmin=-1, vmax=1, cbarlabel="Korrelationskoeffizient")
annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)

plt.tight_layout()
plt.show()

Zusammenfassung

In diesem Lab haben wir gelernt, wie man in Python annotierte Heatmaps mit der imshow-Funktion von Matplotlib erstellt. Wir haben begonnen, indem wir eine einfache kategorische Heatmap erstellt haben und sie dann zu einer wiederverwendbaren Funktion erweitert haben. Schließlich haben wir einige komplexere Heatmap-Beispiele mit unterschiedlichen Argumenten untersucht.