Introdução
A validação cruzada é uma técnica para avaliar modelos de aprendizado de máquina treinando vários modelos em diferentes subconjuntos dos dados disponíveis e avaliando-os no subconjunto complementar. Isso ajuda a evitar o superajuste e garante que o modelo seja generalizável. Neste tutorial, usaremos o scikit-learn para explorar diferentes técnicas de validação cruzada e seu comportamento.
Dicas da Máquina Virtual
Após o término da inicialização da máquina virtual, clique no canto superior esquerdo para mudar para a aba Notebook para acessar o Jupyter Notebook para praticar.
Às vezes, pode ser necessário aguardar alguns segundos para que o Jupyter Notebook termine de carregar. A validação das operações não pode ser automatizada devido a limitações no Jupyter Notebook.
Se você enfrentar problemas durante o aprendizado, sinta-se à vontade para perguntar ao Labby. Forneça feedback após a sessão e resolveremos prontamente o problema para você.
Visualizar os Dados
Nesta etapa, visualizaremos os dados com os quais trabalharemos. Os dados consistem em 100 pontos de dados de entrada gerados aleatoriamente, 3 classes divididas de forma desigual entre os pontos de dados e 10 "grupos" divididos igualmente entre os pontos de dados.
## Gerar os dados de classe/grupo
n_pontos = 100
X = rng.randn(100, 10)
percentis_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentis_classes)])
## Gerar grupos desiguais
prior_grupo = rng.dirichlet([2] * 10)
grupos = np.repeat(np.arange(10), rng.multinomial(100, prior_grupo))
def visualizar_grupos(classes, grupos, nome):
## Visualizar os grupos de dados
fig, ax = plt.subplots()
ax.scatter(
range(len(grupos)),
[0.5] * len(grupos),
c=grupos,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.scatter(
range(len(grupos)),
[3.5] * len(grupos),
c=classes,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.set(
ylim=[-1, 5],
yticks=[0.5, 3.5],
yticklabels=["Grupo de\ndados", "Classe\nde dados"],
xlabel="Índice da amostra",
)
visualizar_grupos(y, grupos, "sem grupos")
Visualizar Índices de Validação Cruzada
Nesta etapa, definiremos uma função para visualizar o comportamento de cada objeto de validação cruzada. Realizaremos 4 divisões dos dados. Em cada divisão, visualizaremos os índices escolhidos para o conjunto de treinamento (em azul) e o conjunto de teste (em vermelho).
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""Criar um gráfico de amostra para índices de um objeto de validação cruzada."""
## Gerar as visualizações de treinamento/teste para cada divisão de validação cruzada
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
## Preencher os índices com os grupos de treinamento/teste
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
## Visualizar os resultados
ax.scatter(
range(len(indices)),
[ii + 0.5] * len(indices),
c=indices,
marker="_",
lw=lw,
cmap=cmap_cv,
vmin=-0.2,
vmax=1.2,
)
## Plotar as classes e grupos de dados no final
ax.scatter(
range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
)
ax.scatter(
range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
)
## Formatação
yticklabels = list(range(n_splits)) + ["classe", "grupo"]
ax.set(
yticks=np.arange(n_splits + 2) + 0.5,
yticklabels=yticklabels,
xlabel="Índice da amostra",
ylabel="Iteração de validação cruzada",
ylim=[n_splits + 2.2, -0.2],
xlim=[0, 100],
)
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
Comparar Objetos de Validação Cruzada
Nesta etapa, compararemos o comportamento de validação cruzada para diferentes objetos de validação cruzada do scikit-learn. Irá percorrer vários objetos de validação cruzada comuns, visualizando o comportamento de cada um. Note como alguns utilizam as informações de grupo/classe, enquanto outros não.
cvs = [
KFold,
GroupKFold,
ShuffleSplit,
StratifiedKFold,
StratifiedGroupKFold,
GroupShuffleSplit,
StratifiedShuffleSplit,
TimeSeriesSplit,
]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend(
[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
["Conjunto de teste", "Conjunto de treinamento"],
loc=(1.02, 0.8),
)
## Ajustar o tamanho da legenda
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()
Resumo
Neste tutorial, utilizamos o scikit-learn para explorar diferentes técnicas de validação cruzada e seu comportamento. Visualizamos os dados, definimos uma função para visualizar os índices de validação cruzada e comparamos o comportamento de validação cruzada para diferentes objetos de validação cruzada do scikit-learn.