はじめに
この実験では、scikit-learn を使って Python で 3 クラス分類におけるシグモイド補正をどのように使用するかを示します。シグモイド補正が予測確率をどのように変えるか、およびモデルの性能を向上させるためにどのように使用できるかを示します。
VM のヒント
VM の起動が完了した後、左上隅をクリックして ノートブック タブに切り替えて、Jupyter Notebook を使った練習にアクセスします。
時々、Jupyter Notebook が読み込み完了するまで数秒待つ必要がある場合があります。Jupyter Notebook の制限により、操作の検証を自動化することはできません。
学習中に問題に直面した場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。そうすると、迅速に問題を解決します。
データ
2000 個のサンプル、2 つの特徴量、および 3 つのターゲット クラスを持つ分類用のデータセットを生成します。次に、データを以下のように分割します。
- 学習用:600 個のサンプル (分類器の学習に使用)
- 検証用:400 個のサンプル (予測確率の補正に使用)
- テスト用:1000 個のサンプル
import numpy as np
from sklearn.datasets import make_blobs
np.random.seed(0)
X, y = make_blobs(
n_samples=2000, n_features=2, centers=3, random_state=42, cluster_std=5.0
)
X_train, y_train = X[:600], y[:600]
X_valid, y_valid = X[600:1000], y[600:1000]
X_train_valid, y_train_valid = X[:1000], y[:1000]
X_test, y_test = X[1000:], y[1000:]
フィッティングと補正
学習用と検証用のデータ (1000 個のサンプル) を連結して、25 個のベース推定器 (木) を持つランダム フォレスト分類器を学習します。これが補正前の分類器です。
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier(n_estimators=25)
clf.fit(X_train_valid, y_train_valid)
補正済みの分類器を学習するには、同じランダム フォレスト分類器から始めて、学習用データサブセット (600 個のサンプル) のみを使用して学習し、次に 2 段階のプロセスで検証用データサブセット (400 個のサンプル) を使用して method='sigmoid' で補正します。
from sklearn.calibration import CalibratedClassifierCV
clf = RandomForestClassifier(n_estimators=25)
clf.fit(X_train, y_train)
cal_clf = CalibratedClassifierCV(clf, method="sigmoid", cv="prefit")
cal_clf.fit(X_valid, y_valid)
確率を比較する
テスト サンプルの予測確率の変化を矢印で示した 2 次元単純形を描画します。
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
colors = ["r", "g", "b"]
clf_probs = clf.predict_proba(X_test)
cal_clf_probs = cal_clf.predict_proba(X_test)
## Plot arrows
for i in range(clf_probs.shape[0]):
plt.arrow(
clf_probs[i, 0],
clf_probs[i, 1],
cal_clf_probs[i, 0] - clf_probs[i, 0],
cal_clf_probs[i, 1] - clf_probs[i, 1],
color=colors[y_test[i]],
head_width=1e-2,
)
## Plot perfect predictions, at each vertex
plt.plot([1.0], [0.0], "ro", ms=20, label="Class 1")
plt.plot([0.0], [1.0], "go", ms=20, label="Class 2")
plt.plot([0.0], [0.0], "bo", ms=20, label="Class 3")
## Plot boundaries of unit simplex
plt.plot([0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], "k", label="Simplex")
## Annotate points 6 points around the simplex, and mid point inside simplex
plt.annotate(
r"($\frac{1}{3}$, $\frac{1}{3}$, $\frac{1}{3}$)",
xy=(1.0 / 3, 1.0 / 3),
xytext=(1.0 / 3, 0.23),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
plt.plot([1.0 / 3], [1.0 / 3], "ko", ms=5)
plt.annotate(
r"($\frac{1}{2}$, $0$, $\frac{1}{2}$)",
xy=(0.5, 0.0),
xytext=(0.5, 0.1),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
plt.annotate(
r"($0$, $\frac{1}{2}$, $\frac{1}{2}$)",
xy=(0.0, 0.5),
xytext=(0.1, 0.5),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
plt.annotate(
r"($\frac{1}{2}$, $\frac{1}{2}$, $0$)",
xy=(0.5, 0.5),
xytext=(0.6, 0.6),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
plt.annotate(
r"($0$, $0$, $1$)",
xy=(0, 0),
xytext=(0.1, 0.1),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
plt.annotate(
r"($1$, $0$, $0$)",
xy=(1, 0),
xytext=(1, 0.1),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
plt.annotate(
r"($0$, $1$, $0$)",
xy=(0, 1),
xytext=(0.1, 1),
xycoords="data",
arrowprops=dict(facecolor="black", shrink=0.05),
horizontalalignment="center",
verticalalignment="center",
)
## Add grid
plt.grid(False)
for x in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
plt.plot([0, x], [x, 0], "k", alpha=0.2)
plt.plot([0, 0 + (1 - x) / 2], [x, x + (1 - x) / 2], "k", alpha=0.2)
plt.plot([x, x + (1 - x) / 2], [0, 0 + (1 - x) / 2], "k", alpha=0.2)
plt.title("Change of predicted probabilities on test samples after sigmoid calibration")
plt.xlabel("Probability class 1")
plt.ylabel("Probability class 2")
plt.xlim(-0.05, 1.05)
plt.ylim(-0.05, 1.05)
_ = plt.legend(loc="best")
ログ損失の比較
1000 個のテスト サンプルの予測に対する、補正前と補正済みの分類器のログ損失を比較します。
from sklearn.metrics import log_loss
score = log_loss(y_test, clf_probs)
cal_score = log_loss(y_test, cal_clf_probs)
print("Log-loss of")
print(f" * uncalibrated classifier: {score:.3f}")
print(f" * calibrated classifier: {cal_score:.3f}")
グリッドを生成して描画する
2 次元単純形上の補正前の確率の可能なグリッドを生成し、対応する補正済みの確率を計算し、それぞれに矢印を描画します。矢印は、最も高い補正前の確率に応じて色付けされます。これは学習された補正マップを示しています。
plt.figure(figsize=(10, 10))
## Generate grid of probability values
p1d = np.linspace(0, 1, 20)
p0, p1 = np.meshgrid(p1d, p1d)
p2 = 1 - p0 - p1
p = np.c_[p0.ravel(), p1.ravel(), p2.ravel()]
p = p[p[:, 2] >= 0]
## Use the three class-wise calibrators to compute calibrated probabilities
calibrated_classifier = cal_clf.calibrated_classifiers_[0]
prediction = np.vstack(
[
calibrator.predict(this_p)
for calibrator, this_p in zip(calibrated_classifier.calibrators, p.T)
]
).T
## Re-normalize the calibrated predictions to make sure they stay inside the
## simplex. This same renormalization step is performed internally by the
## predict method of CalibratedClassifierCV on multiclass problems.
prediction /= prediction.sum(axis=1)[:, None]
## Plot changes in predicted probabilities induced by the calibrators
for i in range(prediction.shape[0]):
plt.arrow(
p[i, 0],
p[i, 1],
prediction[i, 0] - p[i, 0],
prediction[i, 1] - p[i, 1],
head_width=1e-2,
color=colors[np.argmax(p[i])],
)
## Plot the boundaries of the unit simplex
plt.plot([0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], "k", label="Simplex")
plt.grid(False)
for x in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
plt.plot([0, x], [x, 0], "k", alpha=0.2)
plt.plot([0, 0 + (1 - x) / 2], [x, x + (1 - x) / 2], "k", alpha=0.2)
plt.plot([x, x + (1 - x) / 2], [0, 0 + (1 - x) / 2], "k", alpha=0.2)
plt.title("Learned sigmoid calibration map")
plt.xlabel("Probability class 1")
plt.ylabel("Probability class 2")
plt.xlim(-0.05, 1.05)
plt.ylim(-0.05, 1.05)
plt.show()
まとめ
この実験では、scikit-learn を使って Python で 3 クラス分類におけるシグモイド補正をどのように使用するかを示しました。また、補正が予測確率に与える影響と、モデルの性能を向上させるためにどのように使用できるかを示しました。