그리드 검색과 연속 반복 검색 비교

Beginner

This tutorial is from open-source community. Access the source code

소개

이 실험은 머신 러닝에서 두 가지 인기 있는 매개변수 탐색 알고리즘, 격자 검색 (Grid Search) 과 연속 반감법 (Successive Halving) 을 비교합니다. 파이썬의 scikit-learn 라이브러리를 사용하여 비교를 수행할 것입니다. 이러한 알고리즘은 주어진 머신 러닝 모델에 대한 최적의 하이퍼파라미터를 찾는 데 사용됩니다.

VM 팁

VM 시작이 완료되면 왼쪽 상단 모서리를 클릭하여 Notebook 탭으로 전환하여 연습을 위한 Jupyter Notebook에 접근할 수 있습니다.

때때로 Jupyter Notebook 이 완전히 로드되기까지 몇 초 정도 기다려야 할 수 있습니다. Jupyter Notebook 의 제한으로 인해 작업의 유효성 검사를 자동화할 수 없습니다.

학습 중 문제가 발생하면 Labby 에 문의하십시오. 세션 후 피드백을 제공하면 문제를 신속하게 해결해 드리겠습니다.

필요한 라이브러리 및 데이터셋 가져오기

먼저 이 실험에 필요한 라이브러리와 데이터셋을 가져옵니다. 합성 데이터셋을 생성하고 매개변수 검색을 수행하기 위해 scikit-learn 라이브러리를 사용할 것입니다.

from time import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingGridSearchCV

rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)

gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {"gamma": gammas, "C": Cs}

clf = SVC(random_state=rng)

격자 검색 수행

SVC 모델에 대한 매개변수 검색을 위해 격자 검색을 사용할 것입니다. 1 단계에서 생성된 합성 데이터셋과 매개변수 그리드를 사용할 것입니다.

tic = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - tic

연속 반복 검색 수행

이제 2 단계에서 사용된 동일한 SVC 모델과 데이터셋에 대해 연속 반복 검색 (Successive Halving) 을 사용하여 매개변수 검색을 수행할 것입니다.

tic = time()
gsh = HalvingGridSearchCV(
    estimator=clf, param_grid=param_grid, factor=2, random_state=rng
)
gsh.fit(X, y)
gsh_time = time() - tic

결과 시각화

이제 히트맵을 사용하여 매개변수 검색 알고리즘의 결과를 시각화할 것입니다. 히트맵은 SVC 인스턴스에 대한 매개변수 조합의 평균 테스트 점수를 보여줍니다. 연속 반복 검색 히트맵은 조합이 마지막으로 사용된 반복 횟수도 표시합니다.

def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Helper to make a heatmap."""
    results = pd.DataFrame(gs.cv_results_)
    results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
        np.float64
    )
    if is_sh:
        ## SH dataframe: get mean_test_score values for the highest iter
        scores_matrix = results.sort_values("iter").pivot_table(
            index="param_gamma",
            columns="param_C",
            values="mean_test_score",
            aggfunc="last",
        )
    else:
        scores_matrix = results.pivot(
            index="param_gamma", columns="param_C", values="mean_test_score"
        )

    im = ax.imshow(scores_matrix)

    ax.set_xticks(np.arange(len(Cs)))
    ax.set_xticklabels(["{:.0E}".format(x) for x in Cs])
    ax.set_xlabel("C", fontsize=15)

    ax.set_yticks(np.arange(len(gammas)))
    ax.set_yticklabels(["{:.0E}".format(x) for x in gammas])
    ax.set_ylabel("gamma", fontsize=15)

    ## Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    if is_sh:
        iterations = results.pivot_table(
            index="param_gamma", columns="param_C", values="iter", aggfunc="max"
        ).values
        for i in range(len(gammas)):
            for j in range(len(Cs)):
                ax.text(
                    j,
                    i,
                    iterations[i, j],
                    ha="center",
                    va="center",
                    color="w",
                    fontsize=20,
                )

    if make_cbar:
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel("mean_test_score", rotation=-90, va="bottom", fontsize=15)


fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes

make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)

ax1.set_title("Successive Halving\ntime = {:.3f}s".format(gsh_time), fontsize=15)
ax2.set_title("GridSearch\ntime = {:.3f}s".format(gs_time), fontsize=15)

plt.show()

요약

머신 러닝에서 인기 있는 두 가지 매개변수 검색 알고리즘인 그리드 검색 (Grid Search) 과 연속 반복 검색 (Successive Halving) 을 비교했습니다. 파이썬의 scikit-learn 라이브러리를 사용하여 비교를 수행했습니다. 합성 데이터셋을 생성하고 두 알고리즘 모두를 사용하여 SVC 모델에 대한 매개변수 검색을 수행했습니다. 그런 다음 히트맵을 사용하여 결과를 시각화했습니다. 연속 반복 검색 알고리즘이 그리드 검색과 마찬가지로 정확한 매개변수 조합을 훨씬 적은 시간에 찾을 수 있음을 확인했습니다.