데이터 포인트 생성
3 개의 클래스에서 9 개의 샘플로 구성된 데이터 세트를 생성하고 원래 공간에서 점을 플롯하는 것으로 시작합니다. 이 예제에서는 점 번호 3 의 분류에 중점을 둡니다.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis
from matplotlib import cm
from scipy.special import logsumexp
X, y = make_classification(
n_samples=9,
n_features=2,
n_informative=2,
n_redundant=0,
n_classes=3,
n_clusters_per_class=1,
class_sep=1.0,
random_state=0,
)
plt.figure(1)
ax = plt.gca()
for i in range(X.shape[0]):
ax.text(X[i, 0], X[i, 1], str(i), va="center", ha="center")
ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)
ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.axis("equal") ## 원형 경계가 올바르게 표시되도록