Cross-Validation Techniques with Scikit-Learn

This tutorial is from the open-source community.

Introduction

Cross-validation evaluates a machine learning model by repeatedly splitting the available data, training the model on one subset, and scoring it on the complementary held-out subset. Averaging the held-out scores gives a more reliable estimate of how well the model generalizes and helps detect overfitting. In this tutorial, we will use scikit-learn to explore different cross-validation techniques and their behavior.
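
As a minimal sketch of the idea, the snippet below runs 4-fold cross-validation on synthetic data. The data, the variable names X_demo and y_demo, and the choice of LogisticRegression are assumptions made for illustration; they are not part of this tutorial's dataset.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic, illustrative data (an assumption, not the tutorial's dataset):
# 100 samples, 5 features, binary label driven by the first feature.
rng_demo = np.random.RandomState(0)
X_demo = rng_demo.randn(100, 5)
y_demo = (X_demo[:, 0] > 0).astype(int)

# 4-fold cross-validation: each fold is held out once for testing
# while the model is trained on the remaining three folds.
cv = KFold(n_splits=4, shuffle=True, random_state=0)
scores = cross_val_score(LogisticRegression(), X_demo, y_demo, cv=cv)

print(scores)         # one accuracy score per fold
print(scores.mean())  # average held-out accuracy
```

The spread of the per-fold scores is often as informative as their mean: a large spread suggests the estimate depends heavily on which samples were held out.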

Visualize the Data

In this step, we will visualize the data we will be working with. The data consists of 100 randomly generated input data points, 3 classes split unevenly across the data points (10%, 30%, and 60%), and 10 "groups" whose sizes are also uneven.

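import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch
from sklearn.model_selection import (
    GroupKFold,
    GroupShuffleSplit,
    KFold,
    ShuffleSplit,
    StratifiedGroupKFold,
    StratifiedKFold,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
)

## Shared setup for all steps; the seed and colormap choices are assumptions
rng = np.random.RandomState(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4

## Generate the class/group data
n_points = 100
X = rng.randn(n_points, 10)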

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generate uneven groups
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")
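
To check numerically what the plot shows, the sketch below rebuilds the class and group labels and counts their sizes. The seed used here is an illustrative assumption, so the group sizes will not necessarily match the figure above; the class sizes are fixed by percentiles_classes.

```python
import numpy as np

rng_demo = np.random.RandomState(0)  # illustrative seed, an assumption

# Classes: 10%, 30%, 60% of 100 points, stored as contiguous blocks.
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

# Groups: sizes come from a multinomial draw over a Dirichlet prior,
# so they always sum to 100 but are generally unequal.
group_prior = rng_demo.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng_demo.multinomial(100, group_prior))

class_sizes = np.bincount(y)
group_sizes = np.bincount(groups, minlength=10)

print(class_sizes)  # [10 30 60]
print(group_sizes)  # ten counts that sum to 100
```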

Visualize Cross-Validation Indices

In this step, we will define a function to visualize the behavior of each cross-validation object. We will perform 4 splits of the data. On each split, we will visualize the indices chosen for the training set (in blue) and the test set (in red).

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    ## Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax
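
Each row that plot_cv_indices draws is just an array with 0 for training samples and 1 for test samples. The sketch below builds those arrays for a toy input of 8 samples (X_small is a hypothetical name used only here) and prints them instead of plotting.

```python
import numpy as np
from sklearn.model_selection import KFold

X_small = np.arange(8).reshape(8, 1)  # 8 toy samples, illustrative only

rows = []
for train_idx, test_idx in KFold(n_splits=4).split(X_small):
    # Same encoding as plot_cv_indices: 0 = training sample, 1 = test sample
    indices = np.full(len(X_small), np.nan)
    indices[test_idx] = 1
    indices[train_idx] = 0
    rows.append(indices.astype(int).tolist())

for ii, row in enumerate(rows):
    print(ii, row)
# 0 [1, 1, 0, 0, 0, 0, 0, 0]
# 1 [0, 0, 1, 1, 0, 0, 0, 0]
# 2 [0, 0, 0, 0, 1, 1, 0, 0]
# 3 [0, 0, 0, 0, 0, 0, 1, 1]
```

Without shuffling, KFold holds out consecutive blocks of indices, which is why its plot shows a diagonal band of test samples.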

Compare Cross-Validation Objects

In this step, we will compare the cross-validation behavior for different scikit-learn cross-validation objects. We will loop through several common cross-validation objects, visualizing the behavior of each. Note how some use the group/class information while others do not.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Testing set", "Training set"],
        loc=(1.02, 0.8),
    )
    ## Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()
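
The differences visible in the plots can also be verified numerically. The following sketch uses simplified data (equal-sized groups and the names X_demo, y_demo, and groups_demo are assumptions for illustration) to compare how GroupKFold, StratifiedKFold, and plain KFold treat group and class labels.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold

# Illustrative data: 3 unbalanced classes and 10 equal-sized groups.
y_demo = np.repeat([0, 1, 2], [10, 30, 60])
groups_demo = np.repeat(np.arange(10), 10)
X_demo = np.zeros((100, 1))  # feature values do not affect these splitters

# GroupKFold: no group appears in both the training and the test set.
group_overlap = [
    len(set(groups_demo[tr]) & set(groups_demo[tt]))
    for tr, tt in GroupKFold(n_splits=4).split(X_demo, y_demo, groups_demo)
]
print(group_overlap)  # [0, 0, 0, 0]

# StratifiedKFold: every test fold keeps roughly the 10/30/60 class ratio.
strat_counts = [
    np.bincount(y_demo[tt], minlength=3).tolist()
    for _, tt in StratifiedKFold(n_splits=4).split(X_demo, y_demo)
]
print(strat_counts)

# Plain KFold ignores y: with sorted labels and no shuffling,
# some test folds contain only one or two of the classes.
kfold_counts = [
    np.bincount(y_demo[tt], minlength=3).tolist()
    for _, tt in KFold(n_splits=4).split(X_demo)
]
print(kfold_counts)  # [[10, 15, 0], [0, 15, 10], [0, 0, 25], [0, 0, 25]]
```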

Summary

In this tutorial, we used scikit-learn to explore different cross-validation techniques and their behavior. We visualized the data, defined a function to visualize cross-validation indices, and compared cross-validation behavior for different scikit-learn cross-validation objects.
