Introducción
La validación cruzada es una técnica para evaluar modelos de aprendizaje automático mediante el entrenamiento de varios modelos en diferentes subconjuntos de los datos disponibles y la evaluación de estos en el subconjunto complementario. Esto ayuda a evitar el sobreajuste y asegura que el modelo sea generalizable. En este tutorial, usaremos scikit-learn para explorar diferentes técnicas de validación cruzada y su comportamiento.
Consejos sobre la VM
Una vez finalizada la inicialización de la VM, haga clic en la esquina superior izquierda para cambiar a la pestaña Cuaderno y acceder a Jupyter Notebook para practicar.
A veces, es posible que tenga que esperar unos segundos a que Jupyter Notebook termine de cargarse. La validación de operaciones no se puede automatizar debido a las limitaciones de Jupyter Notebook.
Si tiene problemas durante el aprendizaje, no dude en preguntar a Labby. Deje su retroalimentación después de la sesión y le resolveremos el problema inmediatamente.
Visualizar los datos
En este paso, visualizaremos los datos con los que trabajaremos. Los datos constan de 100 puntos de datos de entrada generados aleatoriamente, 3 clases divididas de manera desigual entre los puntos de datos y 10 "grupos" divididos equitativamente entre los puntos de datos.
## Generate the class/group data
n_points = 100
X = rng.randn(100, 10)
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])
## Generate uneven groups
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
def visualize_groups(classes, groups, name):
## Visualize dataset groups
fig, ax = plt.subplots()
ax.scatter(
range(len(groups)),
[0.5] * len(groups),
c=groups,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.scatter(
range(len(groups)),
[3.5] * len(groups),
c=classes,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.set(
ylim=[-1, 5],
yticks=[0.5, 3.5],
yticklabels=["Data\ngroup", "Data\nclass"],
xlabel="Sample index",
)
visualize_groups(y, groups, "no groups")
Visualizar los datos
En este paso, visualizaremos los datos con los que trabajaremos. Los datos constan de 100 puntos de datos de entrada generados aleatoriamente, 3 clases divididas de manera desigual entre los puntos de datos y 10 "grupos" divididos equitativamente entre los puntos de datos.
## Generar los datos de clase/grupo
n_points = 100
X = rng.randn(100, 10)
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])
## Generar grupos desiguales
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
def visualizar_grupos(classes, groups, name):
## Visualizar los grupos del conjunto de datos
fig, ax = plt.subplots()
ax.scatter(
range(len(groups)),
[0.5] * len(groups),
c=groups,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.scatter(
range(len(groups)),
[3.5] * len(groups),
c=classes,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.set(
ylim=[-1, 5],
yticks=[0.5, 3.5],
yticklabels=["Datos\ngrupo", "Datos\nclase"],
xlabel="Índice de muestra",
)
visualizar_grupos(y, groups, "sin grupos")
Visualizar los índices de validación cruzada
En este paso, definiremos una función para visualizar el comportamiento de cada objeto de validación cruzada. Realizaremos 4 divisiones del data. En cada división, visualizaremos los índices elegidos para el conjunto de entrenamiento (en azul) y el conjunto de prueba (en rojo).
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""Create a sample plot for indices of a cross-validation object."""
## Generate the training/testing visualizations for each CV split
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
## Fill in indices with the training/test groups
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
## Visualize the results
ax.scatter(
range(len(indices)),
[ii + 0.5] * len(indices),
c=indices,
marker="_",
lw=lw,
cmap=cmap_cv,
vmin=-0.2,
vmax=1.2,
)
## Plot the data classes and groups at the end
ax.scatter(
range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
)
ax.scatter(
range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
)
## Formatting
yticklabels = list(range(n_splits)) + ["class", "group"]
ax.set(
yticks=np.arange(n_splits + 2) + 0.5,
yticklabels=yticklabels,
xlabel="Índice de muestra",
ylabel="Iteración de validación cruzada",
ylim=[n_splits + 2.2, -0.2],
xlim=[0, 100],
)
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
Visualizar los índices de validación cruzada
En este paso, definiremos una función para visualizar el comportamiento de cada objeto de validación cruzada. Realizaremos 4 divisiones del data. En cada división, visualizaremos los índices elegidos para el conjunto de entrenamiento (en azul) y el conjunto de prueba (en rojo).
def graficar_indices_cv(cv, X, y, group, ax, n_splits, lw=10):
"""Crear una gráfica de muestra para los índices de un objeto de validación cruzada."""
## Generar las visualizaciones de entrenamiento/prueba para cada división de validación cruzada
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
## Llenar los índices con los grupos de entrenamiento/prueba
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
## Visualizar los resultados
ax.scatter(
range(len(indices)),
[ii + 0.5] * len(indices),
c=indices,
marker="_",
lw=lw,
cmap=cmap_cv,
vmin=-0.2,
vmax=1.2,
)
## Graficar las clases y grupos de datos al final
ax.scatter(
range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
)
ax.scatter(
range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
)
## Formateo
yticklabels = list(range(n_splits)) + ["clase", "grupo"]
ax.set(
yticks=np.arange(n_splits + 2) + 0.5,
yticklabels=yticklabels,
xlabel="Índice de muestra",
ylabel="Iteración de validación cruzada",
ylim=[n_splits + 2.2, -0.2],
xlim=[0, 100],
)
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
Comparar objetos de validación cruzada
En este paso, compararemos el comportamiento de la validación cruzada para diferentes objetos de validación cruzada de scikit-learn. Recorreremos varios objetos comunes de validación cruzada, visualizando el comportamiento de cada uno. Observe cómo algunos utilizan la información de grupo/clase mientras que otros no.
cvs = [
KFold,
GroupKFold,
ShuffleSplit,
StratifiedKFold,
StratifiedGroupKFold,
GroupShuffleSplit,
StratifiedShuffleSplit,
TimeSeriesSplit,
]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend(
[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
["Conjunto de prueba", "Conjunto de entrenamiento"],
loc=(1.02, 0.8),
)
## Hacer que la leyenda quepa
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()
Comparar objetos de validación cruzada
En este paso, compararemos el comportamiento de la validación cruzada para diferentes objetos de validación cruzada de scikit-learn. Recorreremos varios objetos comunes de validación cruzada, visualizando el comportamiento de cada uno. Observe cómo algunos utilizan la información de grupo/clase mientras que otros no.
cvs = [
KFold,
GroupKFold,
ShuffleSplit,
StratifiedKFold,
StratifiedGroupKFold,
GroupShuffleSplit,
StratifiedShuffleSplit,
TimeSeriesSplit,
]
Para cada cv en cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend(
[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
["Conjunto de prueba", "Conjunto de entrenamiento"],
loc=(1.02, 0.8),
)
## Hacer que la leyenda quepa
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()
Resumen
En este tutorial, usamos scikit-learn para explorar diferentes técnicas de validación cruzada y su comportamiento. Visualizamos los datos, definimos una función para visualizar los índices de validación cruzada y comparamos el comportamiento de la validación cruzada para diferentes objetos de validación cruzada de scikit-learn.