使用 Scikit-Learn 的交叉验证技术

Beginner

This tutorial is from open-source community. Access the source code

简介

交叉验证是一种评估机器学习模型的技术,它通过在可用数据的不同子集上训练多个模型,并在互补子集上对其进行评估。这有助于避免过拟合,并确保模型具有泛化能力。在本教程中,我们将使用 scikit-learn 来探索不同的交叉验证技术及其行为。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,请随时向 Labby 提问。课程结束后提供反馈,我们将立即为你解决问题。

可视化数据

在这一步中,我们将可视化我们要处理的数据。数据由 100 个随机生成的输入数据点、3 个在数据点中分布不均匀的类别以及 10 个在数据点中均匀分布的“组”组成。

## 生成类别/组数据
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## 生成不均匀的组
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## 可视化数据集组
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["数据\n组", "数据\n类别"],
        xlabel="样本索引"
    )


visualize_groups(y, groups, "无组")

可视化交叉验证索引

在这一步中,我们将定义一个函数来可视化每个交叉验证对象的行为。我们将对数据进行 4 次分割。在每次分割中,我们将可视化选择用于训练集(蓝色)和测试集(红色)的索引。

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """为交叉验证对象的索引创建一个示例图。"""

    ## 为每个交叉验证分割生成训练/测试可视化
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## 用训练/测试组填充索引
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## 可视化结果
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## 最后绘制数据类别和组
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## 格式化
    yticklabels = list(range(n_splits)) + ["类别", "组"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="样本索引",
        ylabel="交叉验证迭代",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

比较交叉验证对象

在这一步中,我们将比较不同的 scikit-learn 交叉验证对象的交叉验证行为。我们将遍历几个常见的交叉验证对象,可视化每个对象的行为。注意有些对象如何使用组/类别信息,而有些则不使用。

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["测试集", "训练集"],
        loc=(1.02, 0.8),
    )
    ## 调整图例使其合适
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

总结

在本教程中,我们使用 scikit-learn 探索了不同的交叉验证技术及其行为。我们对数据进行了可视化,定义了一个函数来可视化交叉验证索引,并比较了不同 scikit-learn 交叉验证对象的交叉验证行为。