Einführung
Beim Datenanalyse möchten wir oft Daten, die von zwei unabhängigen Variablen abhängen, als farbkodiertes Bilddiagramm, das als Heatmap bekannt ist, darstellen. In diesem Lab verwenden wir die imshow-Funktion von Matplotlib, um eine Heatmap mit Anmerkungen zu erstellen. Wir beginnen mit einem einfachen Beispiel und erweitern es, sodass es als universelle Funktion verwendbar ist.
VM-Tipps
Nachdem der VM-Start abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu öffnen.
Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund von Einschränkungen in Jupyter Notebook nicht automatisiert werden.
Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.
Einfache kategorische Heatmap
Wir beginnen mit der Definition einiger Daten. Wir benötigen eine 2D-Liste oder -Array, das die zu farbkodierenden Daten definiert. Anschließend benötigen wir auch zwei Listen oder Arrays von Kategorien. Die Heatmap selbst ist ein imshow-Diagramm mit den Labels auf die Kategorien gesetzt. Wir werden die text-Funktion verwenden, um die Daten selbst zu beschriften, indem wir in jeder Zelle ein Text erstellen, das den Wert dieser Zelle zeigt.
import matplotlib.pyplot as plt
import numpy as np
gemüsesorten = ["Gurke", "Tomate", "Salat", "Spargel", "Kartoffel", "Weizen", "Gerste"]
bauer = ["Bauer Joe", "Upland Bros.", "Smith Gardening", "Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]
ernte = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
[2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
[1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
[0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
[0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
[1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],
[0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])
fig, ax = plt.subplots()
im = ax.imshow(ernte)
## Zeige alle Striche und beschrifte sie mit den jeweiligen Listeinträgen
ax.set_xticks(np.arange(len(bauer)), labels=bauer)
ax.set_yticks(np.arange(len(gemüsesorten)), labels=gemüsesorten)
## Drehe die Strichmarkenbeschriftungen und setze ihre Ausrichtung
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
## Schleife über die Daten Dimensionen und erstelle Textanmerkungen
for i in range(len(gemüsesorten)):
for j in range(len(bauer)):
text = ax.text(j, i, ernte[i, j], ha="center", va="center", color="w")
ax.set_title("Ernte der örtlichen Bauern (in Tonnen/ Jahr)")
fig.tight_layout()
plt.show()
Verwendung des Hilfsfunktions-Code-Stils
Wir werden eine Funktion erstellen, die die Daten sowie die Zeilen- und Spaltenbeschriftungen als Eingabe nimmt und Argumente zulässt, die zum Anpassen des Diagramms verwendet werden. Wir werden die umgebenden Achsenkanten deaktivieren und ein Gitter aus weißen Linien erstellen, um die Zellen voneinander zu trennen. Hier möchten wir auch eine Farbskala erstellen und die Beschriftungen über der Heatmap statt darunter positionieren. Die Anmerkungen sollen je nach Schwellenwert unterschiedliche Farben erhalten, um einen besseren Kontrast zu der Pixelfarbe zu erzielen.
def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs):
"""
Erstellt eine Heatmap aus einem numpy-Array und zwei Listen von Beschriftungen.
Parameter
----------
data
Ein 2D numpy-Array der Größe (M, N).
row_labels
Eine Liste oder ein Array der Länge M mit den Beschriftungen für die Zeilen.
col_labels
Eine Liste oder ein Array der Länge N mit den Beschriftungen für die Spalten.
ax
Eine `matplotlib.axes.Axes`-Instanz, auf die die Heatmap geplottet wird. Wenn nicht angegeben, wird die aktuelle Achse verwendet oder eine neue erstellt. Optional.
cbar_kw
Ein Wörterbuch mit Argumenten für `matplotlib.Figure.colorbar`. Optional.
cbarlabel
Die Bezeichnung für die Farbskala. Optional.
**kwargs
Alle anderen Argumente werden an `imshow` weitergeleitet.
"""
if ax is None:
ax = plt.gca()
if cbar_kw is None:
cbar_kw = {}
## Plotte die Heatmap
im = ax.imshow(data, **kwargs)
## Erstelle die Farbskala
cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
## Zeige alle Striche und beschrifte sie mit den jeweiligen Listeinträgen.
ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
## Lasse die horizontale Achsenbeschriftung oben erscheinen.
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
## Drehe die Strichmarkenbeschriftungen und setze ihre Ausrichtung.
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
## Deaktiviere die Kanten und erstelle ein weißes Gitter.
ax.spines[:].set_visible(False)
ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
ax.tick_params(which="minor", bottom=False, left=False)
return im, cbar
def annotate_heatmap(im, data=None, valfmt="{x:.2f}", textcolors=("black", "white"), threshold=None, **textkw):
"""
Eine Funktion, um eine Heatmap zu annotieren.
Parameter
----------
im
Das AxesImage, das beschriftet werden soll.
data
Die Daten, die zur Annotation verwendet werden. Wenn None, werden die Daten des Bilds verwendet. Optional.
valfmt
Das Format der Anmerkungen innerhalb der Heatmap. Dies sollte entweder die String-Format-Methode verwenden, z.B. "$ {x:.2f}", oder ein `matplotlib.ticker.Formatter` sein. Optional.
textcolors
Ein Paar von Farben. Die erste wird für Werte unterhalb einer Schwelle verwendet, die zweite für die oberhalb. Optional.
threshold
Wert in Daten-Einheiten, gemäß dem die Farben aus textcolors angewendet werden. Wenn None (Standard), wird die Mitte der Farbskala als Trennung verwendet. Optional.
**kwargs
Alle anderen Argumente werden an jeden Aufruf von `text` weitergeleitet, der verwendet wird, um die Textbeschriftungen zu erstellen.
"""
if not isinstance(data, (list, np.ndarray)):
data = im.get_array()
## Normalisiere die Schwelle auf den Farbbereich des Bilds.
if threshold is not None:
threshold = im.norm(threshold)
else:
threshold = im.norm(data.max())/2.
## Setze die Standardausrichtung auf zentriert, aber gestatten es, durch textkw überschrieben zu werden.
kw = dict(horizontalalignment="center", verticalalignment="center")
kw.update(textkw)
## Hole den Formatter, falls ein String übergeben wird
if isinstance(valfmt, str):
valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
## Schleife über die Daten und erstelle für jedes "Pixel" ein `Text`.
## Ändere die Farbe des Texts je nach Daten.
texts = []
for i in range(data.shape[0]):
for j in range(data.shape[1]):
kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
texts.append(text)
return texts
Anwenden der Funktion
Jetzt, da wir die Funktionen haben, können wir sie verwenden, um eine Heatmap mit Anmerkungen zu erstellen. Wir erstellen einen neuen Datensatz, geben weitere Argumente an imshow, verwenden ein ganzzahliges Format für die Anmerkungen und geben einige Farben an. Wir verstecken auch die Diagonalelemente (die alle 1 sind), indem wir einen matplotlib.ticker.FuncFormatter verwenden.
data = np.random.randint(2, 100, size=(7, 7))
y = [f"Book {i}" for i in range(1, 8)]
x = [f"Store {i}" for i in list("ABCDEFG")]
fig, ax = plt.subplots()
im, _ = heatmap(data, y, x, ax=ax, vmin=0, cmap="magma_r", cbarlabel="weekly sold copies")
annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20, textcolors=("red", "white"))
def func(x, pos):
return f"{x:.2f}".replace("0.", ".").replace("1.00", "")
annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)
Komplexere Heatmap-Beispiele
Im Folgenden zeigen wir die Vielseitigkeit der zuvor erstellten Funktionen, indem wir sie in verschiedenen Fällen anwenden und verschiedene Argumente verwenden.
np.random.seed(19680801)
fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))
## Wiederhole das obige Beispiel mit einer anderen Schriftgröße und Farbpalette.
im, _ = heatmap(ernte, gemüsesorten, bauer, ax=ax, cmap="Wistia", cbarlabel="Ernte [t/Jahr]")
annotate_heatmap(im, valfmt="{x:.1f}", size=7)
## Manchmal ist sogar die Daten selbst kategorisch. Hier verwenden wir eine `matplotlib.colors.BoundaryNorm`, um die Daten in Klassen zu bringen und diese zur Färbung des Diagramms zu verwenden, aber auch, um die Klassenbezeichnungen aus einem Array von Klassen zu erhalten.
data = np.random.randn(6, 6)
y = [f"Prod. {i}" for i in range(10, 70, 10)]
x = [f"Cycle {i}" for i in range(1, 7)]
qrates = list("ABCDEFG")
norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])
im, _ = heatmap(data, y, x, ax=ax3, cmap=mpl.colormaps["PiYG"].resampled(7), norm=norm, cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt), cbarlabel="Quality Rating")
annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1, textcolors=("red", "black"))
## Wir können eine Korrelationsmatrix schön plotten. Da diese zwischen -1 und 1 begrenzt ist, verwenden wir diese als vmin und vmax.
corr_matrix = np.corrcoef(ernte)
im, _ = heatmap(corr_matrix, gemüsesorten, gemüsesorten, ax=ax4, cmap="PuOr", vmin=-1, vmax=1, cbarlabel="Korrelationskoeffizient")
annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)
plt.tight_layout()
plt.show()
Zusammenfassung
In diesem Lab haben wir gelernt, wie man in Python annotierte Heatmaps mit der imshow-Funktion von Matplotlib erstellt. Wir haben begonnen, indem wir eine einfache kategorische Heatmap erstellt haben und sie dann zu einer wiederverwendbaren Funktion erweitert haben. Schließlich haben wir einige komplexere Heatmap-Beispiele mit unterschiedlichen Argumenten untersucht.