近傍成分分析

Beginner

This tutorial is from open-source community. Access the source code

はじめに

この実験では、最も近い近傍の分類精度を最大化する距離メトリックを学習する際の近傍成分分析(Neighborhood Components Analysis:NCA)の使用方法を示すことを目的としています。これは、元の点空間と比較したこのメトリックの視覚的な表現を提供します。

VM のヒント

VM の起動が完了したら、左上隅をクリックしてノートブックタブに切り替え、Jupyter Notebook を使って練習しましょう。

Jupyter Notebook が読み込み終わるまで数秒待つことがあります。Jupyter Notebook の制限により、操作の検証を自動化することはできません。

学習中に問題がある場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

データポイントの生成

まず、3 つのクラスからなる 9 つのサンプルのデータセットを生成し、元の空間における点をプロットします。この例では、点番号 3 の分類に焦点を当てます。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.neighbors import NeighborhoodComponentsAnalysis
from matplotlib import cm
from scipy.special import logsumexp

X, y = make_classification(
    n_samples=9,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_classes=3,
    n_clusters_per_class=1,
    class_sep=1.0,
    random_state=0,
)

plt.figure(1)
ax = plt.gca()
for i in range(X.shape[0]):
    ax.text(X[i, 0], X[i, 1], str(i), va="center", ha="center")
    ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax.set_title("Original points")
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
ax.axis("equal")  ## so that boundaries are displayed correctly as circles

近傍を可視化する

次に、データポイント間のリンクを可視化します。点番号 3 と他の点の間のリンクの太さは、それらの距離に比例します。

def link_thickness_i(X, i):
    diff_embedded = X[i] - X
    dist_embedded = np.einsum("ij,ij->i", diff_embedded, diff_embedded)
    dist_embedded[i] = np.inf

    ## compute exponentiated distances (use the log-sum-exp trick to
    ## avoid numerical instabilities
    exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))
    return exp_dist_embedded


def relate_point(X, i, ax):
    pt_i = X[i]
    for j, pt_j in enumerate(X):
        thickness = link_thickness_i(X, i)
        if i!= j:
            line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
            ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])


i = 3
relate_point(X, i, ax)
plt.show()

埋め込みを学習する

次に、NCA を使って埋め込みを学習し、変換後の点をプロットします。その後、埋め込みを使って最も近い近傍を見つけます。

nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
nca = nca.fit(X, y)

plt.figure(2)
ax2 = plt.gca()
X_embedded = nca.transform(X)
relate_point(X_embedded, i, ax2)

for i in range(len(X)):
    ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va="center", ha="center")
    ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

ax2.set_title("NCA embedding")
ax2.axes.get_xaxis().set_visible(False)
ax2.axes.get_yaxis().set_visible(False)
ax2.axis("equal")
plt.show()

まとめ

この実験では、最も近い近傍の分類精度を最大化する距離メトリックを学習するために NCA をどのように使用するかを示しました。データポイント間のリンクを可視化し、元の点空間と変換後の空間を比較しました。