Kalibrierungskurven plotten
Wir trainieren jedes der vier Modelle mit dem kleinen Trainingsdataset und plotten Kalibrierungskurven mit den vorhergesagten Wahrscheinlichkeiten des Testdatasets. Kalibrierungskurven werden erstellt, indem die vorhergesagten Wahrscheinlichkeiten in Intervalle eingeteilt und dann die durchschnittliche vorhergesagte Wahrscheinlichkeit in jedem Intervall gegen die beobachtete Häufigkeit („Anteil positiver“) geplottet wird. Unter der Kalibrierungskurve plotten wir ein Histogramm, das die Verteilung der vorhergesagten Wahrscheinlichkeiten zeigt, oder genauer gesagt, die Anzahl der Proben in jedem Intervall der vorhergesagten Wahrscheinlichkeiten.
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibrationDisplay
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
## Create classifiers
lr = LogisticRegression()
gnb = GaussianNB()
svc = NaivelyCalibratedLinearSVC(C=1.0, dual="auto")
rfc = RandomForestClassifier()
clf_list = [
(lr, "Logistic"),
(gnb, "Naive Bayes"),
(svc, "SVC"),
(rfc, "Random forest"),
]
fig = plt.figure(figsize=(10, 10))
gs = GridSpec(4, 2)
colors = plt.get_cmap("Dark2")
ax_calibration_curve = fig.add_subplot(gs[:2, :2])
calibration_displays = {}
markers = ["^", "v", "s", "o"]
for i, (clf, name) in enumerate(clf_list):
clf.fit(X_train, y_train)
display = CalibrationDisplay.from_estimator(
clf,
X_test,
y_test,
n_bins=10,
name=name,
ax=ax_calibration_curve,
color=colors(i),
marker=markers[i],
)
calibration_displays[name] = display
ax_calibration_curve.grid()
ax_calibration_curve.set_title("Calibration plots")
## Add histogram
grid_positions = [(2, 0), (2, 1), (3, 0), (3, 1)]
for i, (_, name) in enumerate(clf_list):
row, col = grid_positions[i]
ax = fig.add_subplot(gs[row, col])
ax.hist(
calibration_displays[name].y_prob,
range=(0, 1),
bins=10,
label=name,
color=colors(i),
)
ax.set(title=name, xlabel="Mean predicted probability", ylabel="Count")
plt.tight_layout()
plt.show()