KNN 分類結果の可視化
このボーナスステップでは、KNN 分類結果をより深く理解するために可視化を作成します。可視化により、モデルのパフォーマンスがどの程度であったかを確認し、KNN アルゴリズムによって作成された決定境界を理解することができます。
分類におけるデータ可視化の理解:
- 散布図: 特徴量間の関係とクラスの分布を示します。
- 色分け: 異なる色は異なるクラス (種) を表します。
- 訓練データ vs テストデータ: モデルの汎化を理解するのに役立ちます。
- 予測精度: 予測ラベルと実際のラベルの視覚的な比較。
可視化を作成するために、以下のコードを main.py ファイルの末尾に追加してください。
import matplotlib.pyplot as plt
import numpy as np
## 複数の可視化のためのサブプロットを作成
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
## プロット 1: 各クラスに異なる色を付けた訓練データの分布
scatter1 = axes[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis', alpha=0.7, s=50)
axes[0].set_xlabel('Sepal Length (cm)')
axes[0].set_ylabel('Sepal Width (cm)')
axes[0].set_title('Training Data Distribution')
axes[0].legend(*scatter1.legend_elements(), title="Classes")
## プロット 2: テストデータの予測 vs 実際のラベル
## 比較を作成:正しい予測 vs 不正解な予測
test_predictions = clf.predict(X_test)
correct_predictions = (test_predictions == y_test)
## 正しい予測をプロット
correct_mask = correct_predictions
scatter2_correct = axes[1].scatter(X_test[correct_mask, 0], X_test[correct_mask, 1],
c=y_test[correct_mask], cmap='viridis', alpha=0.7, s=50, marker='o')
## 不正解な予測を異なるマーカーでプロット
incorrect_mask = ~correct_predictions
if np.any(incorrect_mask):
scatter2_incorrect = axes[1].scatter(X_test[incorrect_mask, 0], X_test[incorrect_mask, 1],
c=test_predictions[incorrect_mask], cmap='viridis',
alpha=0.7, s=80, marker='x', edgecolors='red', linewidths=2)
axes[1].set_xlabel('Sepal Length (cm)')
axes[1].set_ylabel('Sepal Width (cm)')
axes[1].set_title('Test Data: Predictions vs Actual\n(correct=●, incorrect=✕)')
## 凡例を作成
legend_elements = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=8, label='Correct'),
plt.Line2D([0], [0], marker='x', color='red', markersize=8, label='Incorrect')]
axes[1].legend(handles=legend_elements)
plt.tight_layout()
plt.savefig('knn_classification_results.png', dpi=150, bbox_inches='tight')
print("Visualization saved to knn_classification_results.png")
## 追加:予測精度を表示
accuracy = np.mean(correct_predictions) * 100
print(f"Model Accuracy: {accuracy:.1f}%")
更新されたスクリプトを実行します。
python3 main.py
以下のような出力が表示されるはずです。
... (previous output) ...
Visualization saved to knn_classification_results.png
Model Accuracy: 100.0%
可視化が示すもの:
- 左側のプロット: 実際の種ごとに色分けされた訓練データの分布。
- 右側のプロット: テストデータのポイントを示します。
- 円 (●): 正しく分類されたポイント。
- バツ (✕): 不正解に分類されたポイント (もしあれば)。
- 精度スコア: 正しい予測の全体的なパーセンテージ。
この可視化は、以下のことを理解するのに役立ちます。
- 特徴空間におけるクラスの分布。
- モデルが過学習しているか、うまく汎化しているか。
- 分類にとってどの領域が難しい可能性があるか。
- KNN モデルの効果を視覚的に。