Scikit-Learn 시각화 API

Beginner

This tutorial is from open-source community. Access the source code

소개

Scikit-learn 은 머신러닝 작업을 위한 간단하고 효율적인 API 를 제공하는 인기 있는 Python 라이브러리입니다. Scikit-learn 의 주요 기능 중 하나는 머신러닝 모델에 대한 시각화를 쉽게 생성할 수 있도록 하는 내장 시각화 API 입니다. 이 실습에서는 Scikit-learn 시각화 API 를 사용하여 두 개의 서로 다른 분류기의 ROC 곡선을 비교하는 방법을 살펴볼 것입니다.

VM 팁

VM 시작이 완료되면 왼쪽 상단 모서리를 클릭하여 Notebook 탭으로 전환하여 연습을 위한 Jupyter Notebook에 접속합니다.

때때로 Jupyter Notebook 이 완전히 로드되기까지 몇 초 정도 기다려야 할 수 있습니다. Jupyter Notebook 의 제한으로 인해 작업의 유효성 검사를 자동화할 수 없습니다.

학습 중 문제가 발생하면 Labby 에게 문의하십시오. 세션 후 피드백을 제공하면 문제를 신속하게 해결해 드리겠습니다.

데이터 로드 및 SVC 학습

먼저 와인 데이터셋을 로드하고 이를 이진 분류 문제로 변환합니다. 그런 다음 훈련 데이터셋에 서포트 벡터 분류기를 학습시킵니다.

import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.metrics import RocCurveDisplay
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)

ROC 곡선 그리기

다음으로 RocCurveDisplay.from_estimator 함수를 사용하여 ROC 곡선을 그립니다. 이 함수는 학습된 분류기, 테스트 데이터셋, 실제 레이블을 입력으로 받아 ROC 곡선을 그릴 수 있는 객체를 반환합니다. 그런 다음 show() 메서드를 호출하여 플롯을 표시합니다.

svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
svc_disp.show()

랜덤 포레스트 학습 및 ROC 곡선 그리기

이 단계에서는 랜덤 포레스트 분류기를 학습하고 SVC ROC 곡선과 함께 랜덤 포레스트의 ROC 곡선을 그립니다. 이를 위해 새로운 RandomForestClassifier 객체를 생성하고 훈련 데이터에 맞춥니다. 그런 다음 이 분류기를 사용하여 새로운 RocCurveDisplay 객체를 생성합니다. 또한 동일한 축에 곡선을 그리기 위해 ax 매개변수를 이 함수에 전달합니다. 마지막으로 svc_disp 객체의 plot() 메서드를 호출하여 SVC ROC 곡선을 그립니다.

from sklearn.ensemble import RandomForestClassifier

rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)

ax = plt.gca()
rfc_disp = RocCurveDisplay.from_estimator(rfc, X_test, y_test, ax=ax, alpha=0.8)
svc_disp.plot(ax=ax, alpha=0.8)
plt.show()

요약

이 실습에서는 scikit-learn 시각화 API 를 사용하여 두 가지 다른 분류기의 ROC 곡선을 그리는 방법을 탐색했습니다. 와인 데이터셋을 로드하고 훈련 데이터에 대한 서포트 벡터 분류기를 훈련하는 것으로 시작했습니다. 그런 다음 RocCurveDisplay 함수를 사용하여 이 분류기의 ROC 곡선을 그렸습니다. 마지막으로 랜덤 포레스트 분류기를 훈련하고 SVC ROC 곡선과 함께 랜덤 포레스트의 ROC 곡선을 그렸습니다. scikit-learn 시각화 API 를 사용하면 서로 다른 분류기를 비교하고 성능을 시각화하는 것이 쉽습니다.