Техники кросс-валидации с использованием Scikit-Learn

Beginner

This tutorial is from open-source community. Access the source code

Введение

Кросс-валидация - это метод оценки моделей машинного обучения путём обучения нескольких моделей на различных подмножествах доступных данных и их оценки на дополняющемся подмножестве. Это помогает избежать переобучения и обеспечивает обобщаемость модели. В этом руководстве мы будем использовать scikit-learn для исследования различных методов кросс-валидации и их поведения.

Советы по работе с ВМ

После запуска ВМ нажмите в левом верхнем углу, чтобы переключиться на вкладку Ноутбук и получить доступ к Jupyter Notebook для практики.

Иногда вам может потребоваться подождать несколько секунд, пока Jupyter Notebook загрузится. Валидация операций не может быть автоматизирована из-за ограничений Jupyter Notebook.

Если вы сталкиваетесь с проблемами во время обучения, не стесняйтесь обращаться к Labby. Оставьте отзыв после занятия, и мы оперативно решим проблему для вас.

Визуализация данных

В этом шаге мы визуализируем данные, с которыми будем работать. Данные состоят из 100 случайно сгенерированных входных данных, 3 класса, распределенных неравномерно по данным, и 10 "групп", распределенных равномерно по данным.

## Generate the class/group data
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generate uneven groups
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")

Визуализация данных

В этом шаге мы визуализируем данные, с которыми будем работать. Данные состоят из 100 случайно сгенерированных входных данных, 3 класса, распределенных неравномерно по данным, и 10 "групп", распределенных равномерно по данным.

## Generate the class/group data
n_points = 100
X = rng.randn(100, 10)

percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])

## Generate uneven groups
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))


def visualize_groups(classes, groups, name):
    ## Visualize dataset groups
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")

Визуализация индексов кросс-валидации

В этом шаге мы определим функцию для визуализации поведения каждого объекта кросс-валидации. Мы будем выполнять 4 разбиения данных. На каждом разбиении мы визуализируем индексы, выбранные для обучающего набора (синим цветом) и тестового набора (красным цветом).

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    ## Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Визуализация индексов кросс-валидации

В этом шаге мы определим функцию для визуализации поведения каждого объекта кросс-валидации. Мы будем выполнять 4 разбиения данных. На каждом разбиении мы визуализируем индексы, выбранные для обучающего набора (синим цветом) и тестового набора (красным цветом).

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    ## Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Сравнение объектов кросс-валидации

В этом шаге мы сравним поведение кросс-валидации для различных объектов кросс-валидации в scikit-learn. Мы пройдемся по нескольким общим объектам кросс-валидации, визуализируя поведение каждого. Обратите внимание, как некоторые используют информацию о группе/классе, а другие нет.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Testing set", "Training set"],
        loc=(1.02, 0.8),
    )
    ## Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Сравнение объектов кросс-валидации

В этом шаге мы сравним поведение кросс-валидации для различных объектов кросс-валидации в scikit-learn. Мы пройдемся по нескольким общим объектам кросс-валидации, визуализируя поведение каждого. Обратите внимание, как некоторые используют информацию о группе/классе, а другие нет.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Testing set", "Training set"],
        loc=(1.02, 0.8),
    )
    ## Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Резюме

В этом руководстве мы использовали scikit-learn для исследования различных методов кросс-валидации и их поведения. Мы визуализировали данные, определили функцию для визуализации индексов кросс-валидации и сравнили поведение кросс-валидации для различных объектов кросс-валидации в scikit-learn.