Scikit - Learn を使った交差検証手法

Beginner

This tutorial is from open-source community. Access the source code

はじめに

交差検証は、利用可能なデータの異なるサブセットでいくつかのモデルを学習させ、相補的なサブセットで評価することにより、機械学習モデルを評価する手法です。これにより、過学習を回避し、モデルが汎用性を持つことが保証されます。このチュートリアルでは、scikit-learn を使って、さまざまな交差検証手法とその挙動を調べます。

VM のヒント

VM の起動が完了したら、左上隅をクリックしてノートブックタブに切り替え、Jupyter Notebook を使って練習しましょう。

時々、Jupyter Notebook が読み込み終わるまで数秒待つ必要があります。Jupyter Notebook の制限により、操作の検証を自動化することはできません。

学習中に問題に遭遇した場合は、Labby にお問い合わせください。セッション後にフィードバックを提供してください。すぐに問題を解決いたします。

データを可視化する

このステップでは、扱うデータを可視化します。データは、100 個のランダムに生成された入力データポイント、データポイントに不均一に分割された 3 つのクラス、およびデータポイントに均等に分割された 10 個の「グループ」で構成されています。

## クラス/グループデータを生成する
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## 不均一なグループを生成する
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## データセットのグループを可視化する
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")

交差検証インデックスを可視化する

このステップでは、各交差検証オブジェクトの挙動を可視化する関数を定義します。データを 4 分割します。各分割で、学習セット(青)とテストセット(赤)に選択されたインデックスを可視化します。

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """交差検証オブジェクトのインデックスのサンプルプロットを作成する。"""

    ## 各 CV 分割に対する学習/テストの可視化を生成する
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## 学習/テストグループでインデックスを埋める
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## 結果を可視化する
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## 最後にデータのクラスとグループをプロットする
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## フォーマット設定
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

交差検証オブジェクトを比較する

このステップでは、異なる scikit-learn の交差検証オブジェクトの交差検証の挙動を比較します。いくつかの一般的な交差検証オブジェクトをループ処理し、それぞれの挙動を可視化します。どのようにいくつかはグループ/クラス情報を使用し、他のものは使用しないかをご注目ください。

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Testing set", "Training set"],
        loc=(1.02, 0.8),
    )
    ## 凡例を収まるようにする
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

まとめ

このチュートリアルでは、scikit-learn を使って、さまざまな交差検証手法とその挙動を調べました。データを可視化し、交差検証インデックスを可視化する関数を定義し、異なる scikit-learn の交差検証オブジェクトの交差検証の挙動を比較しました。