使用 GridSearchCV 优化模型超参数

Machine LearningMachine LearningBeginner
立即练习

This tutorial is from open-source community. Access the source code

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

在本教程中,我们将学习 Scikit-Learn 中的 GridSearchCV 函数。GridSearchCV 是一个用于为给定模型搜索最佳超参数的函数。它会在估计器的指定参数值上进行详尽搜索。用于应用这些方法的估计器参数通过在参数网格上进行交叉验证网格搜索来优化。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) sklearn(("Sklearn")) -.-> sklearn/CoreModelsandAlgorithmsGroup(["Core Models and Algorithms"]) sklearn(("Sklearn")) -.-> sklearn/ModelSelectionandEvaluationGroup(["Model Selection and Evaluation"]) sklearn(("Sklearn")) -.-> sklearn/UtilitiesandDatasetsGroup(["Utilities and Datasets"]) sklearn/CoreModelsandAlgorithmsGroup -.-> sklearn/tree("Decision Trees") sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/model_selection("Model Selection") sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/metrics("Metrics") sklearn/UtilitiesandDatasetsGroup -.-> sklearn/datasets("Datasets") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills sklearn/tree -.-> lab-49219{{"使用 GridSearchCV 优化模型超参数"}} sklearn/model_selection -.-> lab-49219{{"使用 GridSearchCV 优化模型超参数"}} sklearn/metrics -.-> lab-49219{{"使用 GridSearchCV 优化模型超参数"}} sklearn/datasets -.-> lab-49219{{"使用 GridSearchCV 优化模型超参数"}} ml/sklearn -.-> lab-49219{{"使用 GridSearchCV 优化模型超参数"}} end

导入库

让我们先导入必要的库。

import numpy as np
from matplotlib import pyplot as plt

from sklearn.datasets import make_hastie_10_2
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

加载数据集

在这一步中,我们将使用 Scikit-Learn 的 make_hastie_10_2 函数加载一个数据集。此函数生成用于二元分类的合成数据集。

X, y = make_hastie_10_2(n_samples=8000, random_state=42)

定义超参数和评估指标

在这一步中,我们将为决策树分类器模型定义超参数以及我们将使用的评估指标。我们将使用 AUC(曲线下面积)和准确率指标。

scoring = {"AUC": "roc_auc", "Accuracy": make_scorer(accuracy_score)}

执行网格搜索

在这一步中,我们将使用 GridSearchCV 函数来执行网格搜索。我们将为决策树分类器模型的 min_samples_split 参数搜索最佳超参数。

gs = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={"min_samples_split": range(2, 403, 20)},
    scoring=scoring,
    refit="AUC",
    n_jobs=2,
    return_train_score=True,
)
gs.fit(X, y)
results = gs.cv_results_

可视化结果

在这一步中,我们将使用图表来可视化网格搜索的结果。我们将绘制训练集和测试集的 AUC 和准确率分数。

plt.figure(figsize=(13, 13))
plt.title("GridSearchCV evaluating using multiple scorers simultaneously", fontsize=16)

plt.xlabel("min_samples_split")
plt.ylabel("Score")

ax = plt.gca()
ax.set_xlim(0, 402)
ax.set_ylim(0.73, 1)

## Get the regular numpy array from the MaskedArray
X_axis = np.array(results["param_min_samples_split"].data, dtype=float)

for scorer, color in zip(sorted(scoring), ["g", "k"]):
    for sample, style in (("train", "--"), ("test", "-")):
        sample_score_mean = results["mean_%s_%s" % (sample, scorer)]
        sample_score_std = results["std_%s_%s" % (sample, scorer)]
        ax.fill_between(
            X_axis,
            sample_score_mean - sample_score_std,
            sample_score_mean + sample_score_std,
            alpha=0.1 if sample == "test" else 0,
            color=color,
        )
        ax.plot(
            X_axis,
            sample_score_mean,
            style,
            color=color,
            alpha=1 if sample == "test" else 0.7,
            label="%s (%s)" % (scorer, sample),
        )

    best_index = np.nonzero(results["rank_test_%s" % scorer] == 1)[0][0]
    best_score = results["mean_test_%s" % scorer][best_index]

    ## Plot a dotted vertical line at the best score for that scorer marked by x
    ax.plot(
        [
            X_axis[best_index],
        ]
        * 2,
        [0, best_score],
        linestyle="-.",
        color=color,
        marker="x",
        markeredgewidth=3,
        ms=8,
    )

    ## Annotate the best score for that scorer
    ax.annotate("%0.2f" % best_score, (X_axis[best_index], best_score + 0.005))

plt.legend(loc="best")
plt.grid(False)
plt.show()

总结

在本教程中,我们学习了 Scikit-Learn 中的 GridSearchCV 函数。我们了解了如何加载数据集、定义超参数和评估指标、执行网格搜索以及使用图表可视化结果。GridSearchCV 是在为给定模型搜索最佳超参数时要使用的一个重要函数。