이웃 성분 분석

Beginner

This tutorial is from open-source community. Access the source code

소개

이 실험실은 최근접 이웃 분류 정확도를 극대화하는 거리 측정값을 학습하는 데 이웃 성분 분석 (NCA) 의 사용을 보여주는 것을 목표로 합니다. 이 측정값을 원래 점 공간과 비교하여 시각적으로 나타냅니다.

가상 머신 팁

가상 머신 시작이 완료되면 왼쪽 상단 모서리를 클릭하여 연습을 위해 Jupyter Notebook에 접속합니다.

때때로 Jupyter Notebook 이 완전히 로드되기까지 몇 초 정도 기다려야 할 수 있습니다. Jupyter Notebook 의 제한으로 인해 작업 검증을 자동화할 수 없습니다.

학습 중 문제가 발생하면 Labby 에 문의하십시오. 세션 후 피드백을 제공하면 문제를 신속하게 해결해 드리겠습니다.

데이터 포인트 생성

3 개의 클래스에서 9 개의 샘플로 구성된 데이터 세트를 생성하고 원래 공간에서 점을 플롯하는 것으로 시작합니다. 이 예제에서는 점 번호 3 의 분류에 중점을 둡니다.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis
from matplotlib import cm
from scipy.special import logsumexp

X, y = make_classification(
    n_samples=9,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_classes=3,
    n_clusters_per_class=1,
    class_sep=1.0,
    random_state=0,
)

plt.figure(1)
ax = plt.gca()
for i in range(X.shape[0]):
    ax.text(X[i, 0], X[i, 1], str(i), va="center", ha="center")
    ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.axis("equal")  ## 원형 경계가 올바르게 표시되도록

이웃 시각화

이제 데이터 포인트 간의 연결을 시각화합니다. 점 번호 3 과 다른 점 사이의 연결 두께는 두 점 사이의 거리에 비례합니다.

def link_thickness_i(X, i):
    diff_embedded = X[i] - X
    dist_embedded = np.einsum("ij,ij->i", diff_embedded, diff_embedded)
    dist_embedded[i] = np.inf

    ## 지수 거리 계산 (로그 - 합-지수 트릭 사용하여
    ## 수치적 불안정성 방지
    exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))
    return exp_dist_embedded


def relate_point(X, i, ax):
    pt_i = X[i]
    for j, pt_j in enumerate(X):
        thickness = link_thickness_i(X, i)
        if i != j:
            line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
            ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])


i = 3
relate_point(X, i, ax)
plt.show()

임베딩 학습

이제 NCA 를 사용하여 임베딩을 학습하고 변환 후 점을 플롯합니다. 그런 다음 임베딩을 사용하여 가장 가까운 이웃을 찾습니다.

nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
nca = nca.fit(X, y)

plt.figure(2)
ax2 = plt.gca()
X_embedded = nca.transform(X)
relate_point(X_embedded, i, ax2)

for i in range(len(X)):
    ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va="center", ha="center")
    ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax2.set_title("NCA 임베딩")
ax2.axes.get_xaxis().set_visible(False)
ax2.axes.get_yaxis().set_visible(False)
ax2.axis("equal")
plt.show()

요약

이 실험에서 우리는 NCA 를 사용하여 가장 가까운 이웃 분류 정확도를 극대화하는 거리 측정값을 학습하는 방법을 보여주었습니다. 데이터 포인트 간의 연결을 시각화하고 원래 점 공간과 변환된 공간을 비교했습니다.