Técnicas de Validação Cruzada com Scikit-Learn

Beginner

This tutorial is from open-source community. Access the source code

Introdução

A validação cruzada é uma técnica para avaliar modelos de aprendizado de máquina treinando vários modelos em diferentes subconjuntos dos dados disponíveis e avaliando-os no subconjunto complementar. Isso ajuda a evitar o superajuste e garante que o modelo seja generalizável. Neste tutorial, usaremos o scikit-learn para explorar diferentes técnicas de validação cruzada e seu comportamento.

Dicas da Máquina Virtual

Após o término da inicialização da máquina virtual, clique no canto superior esquerdo para mudar para a aba Notebook para acessar o Jupyter Notebook para praticar.

Às vezes, pode ser necessário aguardar alguns segundos para que o Jupyter Notebook termine de carregar. A validação das operações não pode ser automatizada devido a limitações no Jupyter Notebook.

Se você enfrentar problemas durante o aprendizado, sinta-se à vontade para perguntar ao Labby. Forneça feedback após a sessão e resolveremos prontamente o problema para você.

Visualizar os Dados

Nesta etapa, visualizaremos os dados com os quais trabalharemos. Os dados consistem em 100 pontos de dados de entrada gerados aleatoriamente, 3 classes divididas de forma desigual entre os pontos de dados e 10 "grupos" divididos igualmente entre os pontos de dados.

## Gerar os dados de classe/grupo
n_pontos = 100
X = rng.randn(100, 10)

percentis_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentis_classes)])

## Gerar grupos desiguais
prior_grupo = rng.dirichlet([2] * 10)
grupos = np.repeat(np.arange(10), rng.multinomial(100, prior_grupo))


def visualizar_grupos(classes, grupos, nome):
    ## Visualizar os grupos de dados
    fig, ax = plt.subplots()
    ax.scatter(
        range(len(grupos)),
        [0.5] * len(grupos),
        c=grupos,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(grupos)),
        [3.5] * len(grupos),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Grupo de\ndados", "Classe\nde dados"],
        xlabel="Índice da amostra",
    )


visualizar_grupos(y, grupos, "sem grupos")

Visualizar Índices de Validação Cruzada

Nesta etapa, definiremos uma função para visualizar o comportamento de cada objeto de validação cruzada. Realizaremos 4 divisões dos dados. Em cada divisão, visualizaremos os índices escolhidos para o conjunto de treinamento (em azul) e o conjunto de teste (em vermelho).

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Criar um gráfico de amostra para índices de um objeto de validação cruzada."""

    ## Gerar as visualizações de treinamento/teste para cada divisão de validação cruzada
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        ## Preencher os índices com os grupos de treinamento/teste
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        ## Visualizar os resultados
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    ## Plotar as classes e grupos de dados no final
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    ## Formatação
    yticklabels = list(range(n_splits)) + ["classe", "grupo"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Índice da amostra",
        ylabel="Iteração de validação cruzada",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    return ax

Comparar Objetos de Validação Cruzada

Nesta etapa, compararemos o comportamento de validação cruzada para diferentes objetos de validação cruzada do scikit-learn. Irá percorrer vários objetos de validação cruzada comuns, visualizando o comportamento de cada um. Note como alguns utilizam as informações de grupo/classe, enquanto outros não.

cvs = [
    KFold,
    GroupKFold,
    ShuffleSplit,
    StratifiedKFold,
    StratifiedGroupKFold,
    GroupShuffleSplit,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
]

for cv in cvs:
    this_cv = cv(n_splits=n_splits)
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

    ax.legend(
        [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
        ["Conjunto de teste", "Conjunto de treinamento"],
        loc=(1.02, 0.8),
    )
    ## Ajustar o tamanho da legenda
    plt.tight_layout()
    fig.subplots_adjust(right=0.7)
plt.show()

Resumo

Neste tutorial, utilizamos o scikit-learn para explorar diferentes técnicas de validação cruzada e seu comportamento. Visualizamos os dados, definimos uma função para visualizar os índices de validação cruzada e comparamos o comportamento de validação cruzada para diferentes objetos de validação cruzada do scikit-learn.