绘制网格搜索数字图

Machine LearningMachine LearningBeginner
立即练习

This tutorial is from open-source community. Access the source code

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

本实验展示了如何使用 scikit-learn 库通过交叉验证进行超参数调优。目的是使用二元分类对手写数字图像进行分类,以便于理解:识别一个数字是否为 8。所使用的数据集是数字数据集。然后,在模型选择步骤中未使用的专用评估集上测量所选超参数和训练模型的性能。

虚拟机提示

虚拟机启动完成后,点击左上角切换到“笔记本”标签,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,请随时向 Labby 提问。课程结束后提供反馈,我们将立即为你解决问题。


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL sklearn(("Sklearn")) -.-> sklearn/CoreModelsandAlgorithmsGroup(["Core Models and Algorithms"]) sklearn(("Sklearn")) -.-> sklearn/ModelSelectionandEvaluationGroup(["Model Selection and Evaluation"]) ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) sklearn/CoreModelsandAlgorithmsGroup -.-> sklearn/svm("Support Vector Machines") sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/model_selection("Model Selection") sklearn/ModelSelectionandEvaluationGroup -.-> sklearn/metrics("Metrics") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills sklearn/svm -.-> lab-49155{{"绘制网格搜索数字图"}} sklearn/model_selection -.-> lab-49155{{"绘制网格搜索数字图"}} sklearn/metrics -.-> lab-49155{{"绘制网格搜索数字图"}} ml/sklearn -.-> lab-49155{{"绘制网格搜索数字图"}} end

加载数据

我们将加载数字数据集并将图像展平为向量。每个 8x8 像素的图像都需要转换为一个 64 像素的向量。因此,我们将得到一个形状为 (n_images, n_pixels) 的最终数据数组。我们还将把数据分成大小相等的训练集和测试集。

from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()

n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target == 8

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

定义网格搜索策略

我们将定义一个函数,该函数将被传递给 GridSearchCV 实例的 refit 参数。它将实现自定义策略,以便从 GridSearchCVcv_results_ 属性中选择最佳候选模型。一旦选择了候选模型,GridSearchCV 实例会自动对其进行重新拟合。

这里的策略是筛选出在精确率和召回率方面表现最佳的模型。从选定的模型中,我们最终选择预测速度最快的模型。请注意,这些自定义选择完全是任意的。

import pandas as pd
from sklearn.metrics import classification_report

def print_dataframe(filtered_cv_results):
    """Pretty print for filtered dataframe"""
    for mean_precision, std_precision, mean_recall, std_recall, params in zip(
        filtered_cv_results["mean_test_precision"],
        filtered_cv_results["std_test_precision"],
        filtered_cv_results["mean_test_recall"],
        filtered_cv_results["std_test_recall"],
        filtered_cv_results["params"],
    ):
        print(
            f"precision: {mean_precision:0.3f} (±{std_precision:0.03f}),"
            f" recall: {mean_recall:0.3f} (±{std_recall:0.03f}),"
            f" for {params}"
        )
    print()


def refit_strategy(cv_results):
    """Define the strategy to select the best estimator.

    The strategy defined here is to filter-out all results below a precision threshold
    of 0.98, rank the remaining by recall and keep all models with one standard
    deviation of the best by recall. Once these models are selected, we can select the
    fastest model to predict.

    Parameters
    ----------
    cv_results : dict of numpy (masked) ndarrays
        CV results as returned by the `GridSearchCV`.

    Returns
    -------
    best_index : int
        The index of the best estimator as it appears in `cv_results`.
    """
    ## print the info about the grid-search for the different scores
    precision_threshold = 0.98

    cv_results_ = pd.DataFrame(cv_results)
    print("All grid-search results:")
    print_dataframe(cv_results_)

    ## Filter-out all results below the threshold
    high_precision_cv_results = cv_results_[
        cv_results_["mean_test_precision"] > precision_threshold
    ]

    print(f"Models with a precision higher than {precision_threshold}:")
    print_dataframe(high_precision_cv_results)

    high_precision_cv_results = high_precision_cv_results[
        [
            "mean_score_time",
            "mean_test_recall",
            "std_test_recall",
            "mean_test_precision",
            "std_test_precision",
            "rank_test_recall",
            "rank_test_precision",
            "params",
        ]
    ]

    ## Select the most performant models in terms of recall
    ## (within 1 sigma from the best)
    best_recall_std = high_precision_cv_results["mean_test_recall"].std()
    best_recall = high_precision_cv_results["mean_test_recall"].max()
    best_recall_threshold = best_recall - best_recall_std

    high_recall_cv_results = high_precision_cv_results[
        high_precision_cv_results["mean_test_recall"] > best_recall_threshold
    ]
    print(
        "Out of the previously selected high precision models, we keep all the\n"
        "the models within one standard deviation of the highest recall model:"
    )
    print_dataframe(high_recall_cv_results)

    ## From the best candidates, select the fastest model to predict
    fastest_top_recall_high_precision_index = high_recall_cv_results[
        "mean_score_time"
    ].idxmin()

    print(
        "\nThe selected final model is the fastest to predict out of the previously\n"
        "selected subset of best models based on precision and recall.\n"
        "Its scoring time is:\n\n"
        f"{high_recall_cv_results.loc[fastest_top_recall_high_precision_index]}"
    )

    return fastest_top_recall_high_precision_index

定义超参数

我们将定义超参数并创建 GridSearchCV 实例。

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

tuned_parameters = [
    {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
    {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
]

grid_search = GridSearchCV(
    SVC(), tuned_parameters, scoring=["precision", "recall"], refit=refit_strategy
)

拟合模型并进行预测

我们将拟合模型并在评估集上进行预测。

grid_search.fit(X_train, y_train)

## 网格搜索使用我们的自定义策略选择的参数是:
grid_search.best_params_

## 最后,我们在留出的评估集上评估微调后的模型:`grid_search` 对象 **已使用我们的自定义重新拟合策略选择的参数在完整训练集上自动重新拟合**。
y_pred = grid_search.predict(X_test)
print(classification_report(y_test, y_pred))

总结

在本实验中,我们学习了如何使用 scikit-learn 库通过交叉验证来进行超参数调整。我们使用了数字数据集,并定义了一种自定义的重新拟合策略,以便从 GridSearchCV 实例的 cv_results_ 属性中选择最佳候选模型。最后,我们在留出的评估集上评估了微调后的模型。