Técnicas de validación cruzada con Scikit-Learn

Machine LearningMachine LearningBeginner
Practicar Ahora

This tutorial is from open-source community. Access the source code

💡 Este tutorial está traducido por IA desde la versión en inglés. Para ver la versión original, puedes hacer clic aquí

Introducción

La validación cruzada es una técnica para evaluar modelos de aprendizaje automático mediante el entrenamiento de varios modelos en diferentes subconjuntos de los datos disponibles y la evaluación de estos en el subconjunto complementario. Esto ayuda a evitar el sobreajuste y asegura que el modelo sea generalizable. En este tutorial, usaremos scikit-learn para explorar diferentes técnicas de validación cruzada y su comportamiento.

Consejos sobre la VM

Una vez finalizada la inicialización de la VM, haga clic en la esquina superior izquierda para cambiar a la pestaña Cuaderno y acceder a Jupyter Notebook para practicar.

A veces, es posible que tenga que esperar unos segundos a que Jupyter Notebook termine de cargarse. La validación de operaciones no se puede automatizar debido a las limitaciones de Jupyter Notebook.

Si tiene problemas durante el aprendizaje, no dude en preguntar a Labby. Deje su retroalimentación después de la sesión y le resolveremos el problema inmediatamente.


Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills ml/sklearn -.-> lab-49100{{"Técnicas de validación cruzada con Scikit-Learn"}} end

Visualizar los datos

En este paso, visualizaremos los datos con los que trabajaremos. Los datos constan de 100 puntos de datos de entrada generados aleatoriamente, 3 clases divididas de manera desigual entre los puntos de datos y 10 "grupos" divididos equitativamente entre los puntos de datos.

## Generate the class/group data
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generate uneven groups
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")

Visualizar los datos

En este paso, visualizaremos los datos con los que trabajaremos. Los datos constan de 100 puntos de datos de entrada generados aleatoriamente, 3 clases divididas de manera desigual entre los puntos de datos y 10 "grupos" divididos equitativamente entre los puntos de datos.

## Generar los datos de clase/grupo
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generar grupos desiguales
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualizar_grupos(classes, groups, name):
    ## Visualizar los grupos del conjunto de datos
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Datos\ngrupo", "Datos\nclase"],
        xlabel="Índice de muestra",
    )


visualizar_grupos(y, groups, "sin grupos")

Visualizar los índices de validación cruzada

En este paso, definiremos una función para visualizar el comportamiento de cada objeto de validación cruzada. Realizaremos 4 divisiones del data. En cada división, visualizaremos los índices elegidos para el conjunto de entrenamiento (en azul) y el conjunto de prueba (en rojo).

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    ## Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Índice de muestra",
        ylabel="Iteración de validación cruzada",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Visualizar los índices de validación cruzada

En este paso, definiremos una función para visualizar el comportamiento de cada objeto de validación cruzada. Realizaremos 4 divisiones del data. En cada división, visualizaremos los índices elegidos para el conjunto de entrenamiento (en azul) y el conjunto de prueba (en rojo).

def graficar_indices_cv(cv, X, y, group, ax, n_splits, lw=10):
    """Crear una gráfica de muestra para los índices de un objeto de validación cruzada."""

    ## Generar las visualizaciones de entrenamiento/prueba para cada división de validación cruzada
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Llenar los índices con los grupos de entrenamiento/prueba
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualizar los resultados
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Graficar las clases y grupos de datos al final
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formateo
    yticklabels = list(range(n_splits)) + ["clase", "grupo"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Índice de muestra",
        ylabel="Iteración de validación cruzada",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Comparar objetos de validación cruzada

En este paso, compararemos el comportamiento de la validación cruzada para diferentes objetos de validación cruzada de scikit-learn. Recorreremos varios objetos comunes de validación cruzada, visualizando el comportamiento de cada uno. Observe cómo algunos utilizan la información de grupo/clase mientras que otros no.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Conjunto de prueba", "Conjunto de entrenamiento"],
        loc=(1.02, 0.8),
    )
    ## Hacer que la leyenda quepa
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Comparar objetos de validación cruzada

En este paso, compararemos el comportamiento de la validación cruzada para diferentes objetos de validación cruzada de scikit-learn. Recorreremos varios objetos comunes de validación cruzada, visualizando el comportamiento de cada uno. Observe cómo algunos utilizan la información de grupo/clase mientras que otros no.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

Para cada cv en cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Conjunto de prueba", "Conjunto de entrenamiento"],
        loc=(1.02, 0.8),
    )
    ## Hacer que la leyenda quepa
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Resumen

En este tutorial, usamos scikit-learn para explorar diferentes técnicas de validación cruzada y su comportamiento. Visualizamos los datos, definimos una función para visualizar los índices de validación cruzada y comparamos el comportamiento de la validación cruzada para diferentes objetos de validación cruzada de scikit-learn.