简介
交叉验证是一种评估机器学习模型的技术,它通过在可用数据的不同子集上训练多个模型,并在互补子集上对其进行评估。这有助于避免过拟合,并确保模型具有泛化能力。在本教程中,我们将使用 scikit-learn 来探索不同的交叉验证技术及其行为。
虚拟机使用提示
虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。
有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。
如果你在学习过程中遇到问题,请随时向 Labby 提问。课程结束后提供反馈,我们将立即为你解决问题。
可视化数据
在这一步中,我们将可视化我们要处理的数据。数据由 100 个随机生成的输入数据点、3 个在数据点中分布不均匀的类别以及 10 个在数据点中均匀分布的“组”组成。
## 生成类别/组数据
n_points = 100
X = rng.randn(100, 10)
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])
## 生成不均匀的组
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
def visualize_groups(classes, groups, name):
## 可视化数据集组
fig, ax = plt.subplots()
ax.scatter(
range(len(groups)),
[0.5] * len(groups),
c=groups,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.scatter(
range(len(groups)),
[3.5] * len(groups),
c=classes,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.set(
ylim=[-1, 5],
yticks=[0.5, 3.5],
yticklabels=["数据\n组", "数据\n类别"],
xlabel="样本索引"
)
visualize_groups(y, groups, "无组")
可视化交叉验证索引
在这一步中,我们将定义一个函数来可视化每个交叉验证对象的行为。我们将对数据进行 4 次分割。在每次分割中,我们将可视化选择用于训练集(蓝色)和测试集(红色)的索引。
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""为交叉验证对象的索引创建一个示例图。"""
## 为每个交叉验证分割生成训练/测试可视化
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
## 用训练/测试组填充索引
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
## 可视化结果
ax.scatter(
range(len(indices)),
[ii + 0.5] * len(indices),
c=indices,
marker="_",
lw=lw,
cmap=cmap_cv,
vmin=-0.2,
vmax=1.2,
)
## 最后绘制数据类别和组
ax.scatter(
range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
)
ax.scatter(
range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
)
## 格式化
yticklabels = list(range(n_splits)) + ["类别", "组"]
ax.set(
yticks=np.arange(n_splits + 2) + 0.5,
yticklabels=yticklabels,
xlabel="样本索引",
ylabel="交叉验证迭代",
ylim=[n_splits + 2.2, -0.2],
xlim=[0, 100],
)
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
比较交叉验证对象
在这一步中,我们将比较不同的 scikit-learn 交叉验证对象的交叉验证行为。我们将遍历几个常见的交叉验证对象,可视化每个对象的行为。注意有些对象如何使用组/类别信息,而有些则不使用。
cvs = [
KFold,
GroupKFold,
ShuffleSplit,
StratifiedKFold,
StratifiedGroupKFold,
GroupShuffleSplit,
StratifiedShuffleSplit,
TimeSeriesSplit,
]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend(
[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
["测试集", "训练集"],
loc=(1.02, 0.8),
)
## 调整图例使其合适
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()
总结
在本教程中,我们使用 scikit-learn 探索了不同的交叉验证技术及其行为。我们对数据进行了可视化,定义了一个函数来可视化交叉验证索引,并比较了不同 scikit-learn 交叉验证对象的交叉验证行为。