Comparaison entre la Recherche en grille et la Décimation successive

Beginner

This tutorial is from open-source community. Access the source code

Introduction

Ce laboratoire compare deux algorithmes de recherche de paramètres populaires en apprentissage automatique : la Recherche en grille et la Décimation successive. Nous utiliserons la bibliothèque scikit-learn en Python pour effectuer la comparaison. Ces algorithmes sont utilisés pour trouver les meilleurs hyperparamètres pour un modèle d'apprentissage automatique donné.

Conseils sur la machine virtuelle

Une fois le démarrage de la machine virtuelle terminé, cliquez dans le coin supérieur gauche pour basculer vers l'onglet Carnet de notes pour accéder au carnet Jupyter pour pratiquer.

Parfois, vous devrez peut-être attendre quelques secondes pour que le carnet Jupyter ait fini de charger. La validation des opérations ne peut pas être automatisée en raison des limitations du carnet Jupyter.

Si vous rencontrez des problèmes pendant l'apprentissage, n'hésitez pas à demander à Labby. Donnez des commentaires après la session, et nous résoudrons rapidement le problème pour vous.

Importez les bibliothèques et le jeu de données nécessaires

Nous commençons par importer les bibliothèques et le jeu de données nécessaires pour ce laboratoire. Nous utiliserons la bibliothèque scikit-learn pour générer un jeu de données synthétique et effectuer la recherche de paramètres.

from time import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingGridSearchCV

rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)

gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {"gamma": gammas, "C": Cs}

clf = SVC(random_state=rng)

Effectuez une Recherche en grille

Nous utiliserons la Recherche en grille pour effectuer une recherche de paramètres sur le modèle SVC. Nous utiliserons le jeu de données synthétique généré et la grille de paramètres générée dans l'Étape 1.

tic = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - tic

Effectuez la Décimation successive

Nous allons maintenant effectuer une recherche de paramètres en utilisant la Décimation successive sur le même modèle SVC et le même jeu de données utilisés dans l'Étape 2.

tic = time()
gsh = HalvingGridSearchCV(
    estimator=clf, param_grid=param_grid, factor=2, random_state=rng
)
gsh.fit(X, y)
gsh_time = time() - tic

Visualisez les résultats

Nous allons maintenant visualiser les résultats des algorithmes de recherche de paramètres à l'aide de cartes thermiques. Les cartes thermiques montrent la note de test moyenne des combinaisons de paramètres pour l'instance SVC. La carte thermique de la Décimation successive montre également l'itération à laquelle les combinaisons ont été utilisées pour la dernière fois.

def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Helper to make a heatmap."""
    results = pd.DataFrame(gs.cv_results_)
    results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
        np.float64
    )
    if is_sh:
        ## SH dataframe: get mean_test_score values for the highest iter
        scores_matrix = results.sort_values("iter").pivot_table(
            index="param_gamma",
            columns="param_C",
            values="mean_test_score",
            aggfunc="last",
        )
    else:
        scores_matrix = results.pivot(
            index="param_gamma", columns="param_C", values="mean_test_score"
        )

    im = ax.imshow(scores_matrix)

    ax.set_xticks(np.arange(len(Cs)))
    ax.set_xticklabels(["{:.0E}".format(x) for x in Cs])
    ax.set_xlabel("C", fontsize=15)

    ax.set_yticks(np.arange(len(gammas)))
    ax.set_yticklabels(["{:.0E}".format(x) for x in gammas])
    ax.set_ylabel("gamma", fontsize=15)

    ## Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    if is_sh:
        iterations = results.pivot_table(
            index="param_gamma", columns="param_C", values="iter", aggfunc="max"
        ).values
        for i in range(len(gammas)):
            for j in range(len(Cs)):
                ax.text(
                    j,
                    i,
                    iterations[i, j],
                    ha="center",
                    va="center",
                    color="w",
                    fontsize=20,
                )

    if make_cbar:
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15)


fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes

make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)

ax1.set_title("Décimation successive\ntime = {:.3f}s".format(gsh_time), fontsize=15)
ax2.set_title("Recherche en grille\ntime = {:.3f}s".format(gs_time), fontsize=15)

plt.show()

Sommaire

Nous avons comparé deux algorithmes de recherche de paramètres populaires en apprentissage automatique : la Recherche en grille et la Décimation successive. Nous avons utilisé la bibliothèque scikit-learn en Python pour effectuer la comparaison. Nous avons généré un jeu de données synthétique et effectué une recherche de paramètres sur le modèle SVC à l'aide des deux algorithmes. Nous avons ensuite visualisé les résultats à l'aide de cartes thermiques. Nous avons constaté que l'algorithme de Décimation successive était capable de trouver des combinaisons de paramètres tout aussi précises que la Recherche en grille, en beaucoup moins de temps.