Einführung
In diesem Tutorial lernen wir die Funktion GridSearchCV in Scikit-Learn kennen. GridSearchCV ist eine Funktion, die verwendet wird, um die besten Hyperparameter für ein gegebenes Modell zu suchen. Es ist eine erschöpfende Suche über die angegebenen Parameterwerte für einen Schätzer. Die Parameter des Schätzers, der verwendet wird, um diese Methoden anzuwenden, werden durch Kreuzvalidierung über ein Parametergitter optimiert.
Tipps für die virtuelle Maschine
Nachdem der Start der virtuellen Maschine abgeschlossen ist, klicken Sie in der oberen linken Ecke, um zur Registerkarte Notebook zu wechseln und Jupyter Notebook für die Übung zu nutzen.
Manchmal müssen Sie einige Sekunden warten, bis Jupyter Notebook vollständig geladen ist. Die Validierung von Vorgängen kann aufgrund der Einschränkungen in Jupyter Notebook nicht automatisiert werden.
Wenn Sie bei der Lernphase Probleme haben, können Sie Labby gerne fragen. Geben Sie nach der Sitzung Feedback, und wir werden das Problem für Sie prompt beheben.
Bibliotheken importieren
Lassen Sie uns beginnen, indem wir die erforderlichen Bibliotheken importieren.
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_hastie_10_2
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
Datensatz laden
In diesem Schritt laden wir einen Datensatz mit der Funktion make_hastie_10_2 von Scikit-Learn. Diese Funktion generiert einen synthetischen Datensatz für die binäre Klassifikation.
X, y = make_hastie_10_2(n_samples=8000, random_state=42)
Hyperparameter und Evaluationsmetriken definieren
In diesem Schritt definieren wir die Hyperparameter für das DecisionTreeClassifier-Modell und die Evaluationsmetriken, die wir verwenden werden. Wir werden die AUC (Area Under the Curve) und die Genauigkeitsmetrik verwenden.
scoring = {"AUC": "roc_auc", "Accuracy": make_scorer(accuracy_score)}
Das Gitter-Search ausführen
In diesem Schritt verwenden wir die GridSearchCV-Funktion, um das Gitter-Search durchzuführen. Wir suchen die besten Hyperparameter für den min_samples_split-Parameter des DecisionTreeClassifier-Modells.
gs = GridSearchCV(
DecisionTreeClassifier(random_state=42),
param_grid={"min_samples_split": range(2, 403, 20)},
scoring=scoring,
refit="AUC",
n_jobs=2,
return_train_score=True,
)
gs.fit(X, y)
results = gs.cv_results_
Die Ergebnisse visualisieren
In diesem Schritt werden wir die Ergebnisse der Gitter-Suche mit einem Diagramm visualisieren. Wir werden die AUC- und Genauigkeitsscores für sowohl den Trainings- als auch den Testdatensatz plotten.
plt.figure(figsize=(13, 13))
plt.title("GridSearchCV evaluating using multiple scorers simultaneously", fontsize=16)
plt.xlabel("min_samples_split")
plt.ylabel("Score")
ax = plt.gca()
ax.set_xlim(0, 402)
ax.set_ylim(0.73, 1)
## Get the regular numpy array from the MaskedArray
X_axis = np.array(results["param_min_samples_split"].data, dtype=float)
for scorer, color in zip(sorted(scoring), ["g", "k"]):
for sample, style in (("train", "--"), ("test", "-")):
sample_score_mean = results["mean_%s_%s" % (sample, scorer)]
sample_score_std = results["std_%s_%s" % (sample, scorer)]
ax.fill_between(
X_axis,
sample_score_mean - sample_score_std,
sample_score_mean + sample_score_std,
alpha=0.1 if sample == "test" else 0,
color=color,
)
ax.plot(
X_axis,
sample_score_mean,
style,
color=color,
alpha=1 if sample == "test" else 0.7,
label="%s (%s)" % (scorer, sample),
)
best_index = np.nonzero(results["rank_test_%s" % scorer] == 1)[0][0]
best_score = results["mean_test_%s" % scorer][best_index]
## Plot a dotted vertical line at the best score for that scorer marked by x
ax.plot(
[
X_axis[best_index],
]
* 2,
[0, best_score],
linestyle="-.",
color=color,
marker="x",
markeredgewidth=3,
ms=8,
)
## Annotate the best score for that scorer
ax.annotate("%0.2f" % best_score, (X_axis[best_index], best_score + 0.005))
plt.legend(loc="best")
plt.grid(False)
plt.show()
Zusammenfassung
In diesem Tutorial haben wir uns mit der GridSearchCV-Funktion in Scikit-Learn beschäftigt. Wir haben gesehen, wie man einen Datensatz lädt, Hyperparameter und Evaluationsmetriken definiert, das Gitter-Search ausführt und die Ergebnisse mit einem Diagramm visualisiert. GridSearchCV ist eine wichtige Funktion, um die besten Hyperparameter für ein gegebenes Modell zu suchen.