소개
교차 검증 (Cross-validation) 은 사용 가능한 데이터의 서브셋에서 여러 모델을 학습시키고 보완적인 서브셋으로 평가하여 머신 러닝 모델을 평가하는 기술입니다. 이를 통해 과적합을 방지하고 모델의 일반화 가능성을 확보하는 데 도움이 됩니다. 이 튜토리얼에서는 scikit-learn 을 사용하여 다양한 교차 검증 기법과 그 동작 방식을 살펴볼 것입니다.
VM 팁
VM 시작이 완료되면 왼쪽 상단 모서리를 클릭하여 Notebook 탭으로 전환하여 연습을 위한 Jupyter Notebook에 접근하십시오.
때때로 Jupyter Notebook 이 완전히 로드되기까지 몇 초 정도 기다려야 할 수 있습니다. Jupyter Notebook 의 제한으로 인해 작업의 유효성 검사를 자동화할 수 없습니다.
학습 중 문제가 발생하면 Labby 에 문의하십시오. 세션 후 피드백을 제공하면 문제를 신속하게 해결해 드리겠습니다.
데이터 시각화
이 단계에서는 작업할 데이터를 시각화합니다. 데이터는 100 개의 임의로 생성된 입력 데이터 포인트, 데이터 포인트에 불균형적으로 분포된 3 개의 클래스, 그리고 데이터 포인트에 균등하게 분포된 10 개의 "그룹"으로 구성됩니다.
## 클래스/그룹 데이터 생성
n_points = 100
X = rng.randn(100, 10)
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])
## 불균등한 그룹 생성
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
def visualize_groups(classes, groups, name):
## 데이터셋 그룹 시각화
fig, ax = plt.subplots()
ax.scatter(
range(len(groups)),
[0.5] * len(groups),
c=groups,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.scatter(
range(len(groups)),
[3.5] * len(groups),
c=classes,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.set(
ylim=[-1, 5],
yticks=[0.5, 3.5],
yticklabels=["데이터\n그룹", "데이터\n클래스"],
xlabel="샘플 인덱스",
)
visualize_groups(y, groups, "no groups")
교차 검증 지표 시각화
이 단계에서는 각 교차 검증 객체의 동작을 시각화하는 함수를 정의합니다. 데이터를 4 개의 분할로 나눕니다. 각 분할에서 학습 세트 (파란색) 와 테스트 세트 (빨간색) 에 선택된 인덱스를 시각화합니다.
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""교차 검증 객체의 인덱스에 대한 샘플 플롯을 만듭니다."""
## 각 CV 분할에 대한 학습/테스트 시각화 생성
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
## 학습/테스트 그룹으로 인덱스 채우기
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
## 결과 시각화
ax.scatter(
range(len(indices)),
[ii + 0.5] * len(indices),
c=indices,
marker="_",
lw=lw,
cmap=cmap_cv,
vmin=-0.2,
vmax=1.2,
)
## 마지막에 데이터 클래스와 그룹 플롯
ax.scatter(
range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
)
ax.scatter(
range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
)
## 서식 지정
yticklabels = list(range(n_splits)) + ["class", "group"]
ax.set(
yticks=np.arange(n_splits + 2) + 0.5,
yticklabels=yticklabels,
xlabel="샘플 인덱스",
ylabel="CV 반복",
ylim=[n_splits + 2.2, -0.2],
xlim=[0, 100],
)
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
교차 검증 객체 비교
이 단계에서는 scikit-learn 교차 검증 객체의 서로 다른 동작을 비교합니다. 여러 가지 일반적인 교차 검증 객체를 반복하여 각 객체의 동작을 시각화합니다. 일부 객체는 그룹/클래스 정보를 사용하는 반면, 다른 객체는 사용하지 않는 점에 유의하십시오.
cvs = [
KFold,
GroupKFold,
ShuffleSplit,
StratifiedKFold,
StratifiedGroupKFold,
GroupShuffleSplit,
StratifiedShuffleSplit,
TimeSeriesSplit,
]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend(
[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
["테스트 세트", "학습 세트"],
loc=(1.02, 0.8),
)
## 범례 크기 조정
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()
요약
이 튜토리얼에서는 scikit-learn 을 사용하여 다양한 교차 검증 기법과 그 동작을 탐색했습니다. 데이터를 시각화하고, 교차 검증 인덱스를 시각화하는 함수를 정의했으며, 서로 다른 scikit-learn 교차 검증 객체의 동작을 비교했습니다.