Scikit-Learn 을 이용한 교차 검증 기법

Beginner

This tutorial is from open-source community. Access the source code

소개

교차 검증 (Cross-validation) 은 사용 가능한 데이터의 서브셋에서 여러 모델을 학습시키고 보완적인 서브셋으로 평가하여 머신 러닝 모델을 평가하는 기술입니다. 이를 통해 과적합을 방지하고 모델의 일반화 가능성을 확보하는 데 도움이 됩니다. 이 튜토리얼에서는 scikit-learn 을 사용하여 다양한 교차 검증 기법과 그 동작 방식을 살펴볼 것입니다.

VM 팁

VM 시작이 완료되면 왼쪽 상단 모서리를 클릭하여 Notebook 탭으로 전환하여 연습을 위한 Jupyter Notebook에 접근하십시오.

때때로 Jupyter Notebook 이 완전히 로드되기까지 몇 초 정도 기다려야 할 수 있습니다. Jupyter Notebook 의 제한으로 인해 작업의 유효성 검사를 자동화할 수 없습니다.

학습 중 문제가 발생하면 Labby 에 문의하십시오. 세션 후 피드백을 제공하면 문제를 신속하게 해결해 드리겠습니다.

데이터 시각화

이 단계에서는 작업할 데이터를 시각화합니다. 데이터는 100 개의 임의로 생성된 입력 데이터 포인트, 데이터 포인트에 불균형적으로 분포된 3 개의 클래스, 그리고 데이터 포인트에 균등하게 분포된 10 개의 "그룹"으로 구성됩니다.

## 클래스/그룹 데이터 생성
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## 불균등한 그룹 생성
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## 데이터셋 그룹 시각화
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["데이터\n그룹", "데이터\n클래스"],
        xlabel="샘플 인덱스",
    )


visualize_groups(y, groups, "no groups")

교차 검증 지표 시각화

이 단계에서는 각 교차 검증 객체의 동작을 시각화하는 함수를 정의합니다. 데이터를 4 개의 분할로 나눕니다. 각 분할에서 학습 세트 (파란색) 와 테스트 세트 (빨간색) 에 선택된 인덱스를 시각화합니다.

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """교차 검증 객체의 인덱스에 대한 샘플 플롯을 만듭니다."""

    ## 각 CV 분할에 대한 학습/테스트 시각화 생성
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## 학습/테스트 그룹으로 인덱스 채우기
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## 결과 시각화
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## 마지막에 데이터 클래스와 그룹 플롯
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## 서식 지정
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="샘플 인덱스",
        ylabel="CV 반복",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

교차 검증 객체 비교

이 단계에서는 scikit-learn 교차 검증 객체의 서로 다른 동작을 비교합니다. 여러 가지 일반적인 교차 검증 객체를 반복하여 각 객체의 동작을 시각화합니다. 일부 객체는 그룹/클래스 정보를 사용하는 반면, 다른 객체는 사용하지 않는 점에 유의하십시오.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["테스트 세트", "학습 세트"],
        loc=(1.02, 0.8),
    )
    ## 범례 크기 조정
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

요약

이 튜토리얼에서는 scikit-learn 을 사용하여 다양한 교차 검증 기법과 그 동작을 탐색했습니다. 데이터를 시각화하고, 교차 검증 인덱스를 시각화하는 함수를 정의했으며, 서로 다른 scikit-learn 교차 검증 객체의 동작을 비교했습니다.