선형 회귀 예제

Beginner

This tutorial is from open-source community. Access the source code

소개

이 실습에서는 선형 회귀를 사용하여 데이터 집합에 가장 잘 맞는 직선을 그리는 방법과 계수, 잔차 제곱합, 결정 계수를 계산하는 방법을 보여줍니다. scikit-learn 라이브러리를 사용하여 당뇨병 데이터 세트에 대한 선형 회귀를 수행할 것입니다.

VM 팁

VM 시작이 완료되면 왼쪽 상단 모서리를 클릭하여 Notebook 탭으로 전환하여 연습을 위한 Jupyter Notebook에 접근합니다.

때때로 Jupyter Notebook 이 완전히 로드되기까지 몇 초 정도 기다려야 할 수 있습니다. Jupyter Notebook 의 제한으로 인해 작업의 유효성 검사를 자동화할 수 없습니다.

학습 중 문제가 발생하면 Labby 에게 문의하십시오. 세션 후 피드백을 제공하면 문제를 신속하게 해결해 드리겠습니다.

당뇨병 데이터셋 로드

scikit-learn 에서 당뇨병 데이터셋을 로드하고 데이터셋에서 하나의 특징만 선택하여 시작합니다.

import numpy as np
from sklearn import datasets

## 당뇨병 데이터셋 로드
diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)

## 하나의 특징만 사용
diabetes_X = diabetes_X[:, np.newaxis, 2]

데이터셋 분할

다음으로, 데이터셋을 학습용 및 테스트용 데이터셋으로 분할합니다. 학습에는 데이터의 80%, 테스트에는 20% 를 사용할 것입니다.

## 데이터를 학습/테스트 세트로 분할
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]

## 대상 변수를 학습/테스트 세트로 분할
diabetes_y_train = diabetes_y[:-20]
diabetes_y_test = diabetes_y[-20:]

모델 학습

이제 선형 회귀 객체를 생성하고 학습 데이터셋을 사용하여 모델을 학습합니다.

from sklearn import linear_model

## 선형 회귀 객체 생성
regr = linear_model.LinearRegression()

## 학습 데이터셋을 사용하여 모델 학습
regr.fit(diabetes_X_train, diabetes_y_train)

예측 수행

이제 학습된 모델을 사용하여 테스트 데이터셋에 대한 예측을 수행할 수 있습니다.

## 테스트 데이터셋을 사용하여 예측 수행
diabetes_y_pred = regr.predict(diabetes_X_test)

지표 계산

계수, 평균 제곱 오차, 결정 계수를 계산할 수 있습니다.

from sklearn.metrics import mean_squared_error, r2_score

## 계수
print("계수: \n", regr.coef_)

## 평균 제곱 오차
print("평균 제곱 오차: %.2f"
      % mean_squared_error(diabetes_y_test, diabetes_y_pred))

## 결정 계수: 1 은 완벽한 예측
print("결정 계수: %.2f"
      % r2_score(diabetes_y_test, diabetes_y_pred))

결과 시각화

마지막으로, 예측 값과 실제 값을 플롯하여 모델이 데이터에 얼마나 잘 맞는지 시각적으로 확인할 수 있습니다.

import matplotlib.pyplot as plt

## 출력 플롯
plt.scatter(diabetes_X_test, diabetes_y_test, color="black")
plt.plot(diabetes_X_test, diabetes_y_pred, color="blue", linewidth=3)

plt.xticks(())
plt.yticks(())

plt.show()

요약

이 실습에서는 선형 회귀를 사용하여 데이터 세트에 직선을 맞추는 방법과 계수, 잔차 제곱합, 결정 계수를 계산하는 방법을 배웠습니다. 또한 산점도를 사용하여 예측 값과 실제 값을 시각적으로 비교하는 방법도 배웠습니다.