Validação Cruzada em Conjunto de Dados de Dígitos

Beginner

This tutorial is from open-source community. Access the source code

Introdução

Este laboratório utiliza validação cruzada com uma máquina de vetores de suporte (SVM) no conjunto de dados de dígitos. Este é um problema de classificação, onde a tarefa é identificar dígitos a partir de imagens de dígitos manuscritos.

Dicas da Máquina Virtual

Após o arranque da máquina virtual, clique no canto superior esquerdo para mudar para a aba Notebook para aceder ao Jupyter Notebook para prática.

Por vezes, pode ser necessário esperar alguns segundos para o Jupyter Notebook terminar de carregar. A validação de operações não pode ser automatizada devido a limitações no Jupyter Notebook.

Se tiver problemas durante a aprendizagem, não hesite em contactar o Labby. Forneça feedback após a sessão e resolveremos prontamente o problema para si.

Carregar o conjunto de dados

Primeiro, precisamos carregar o conjunto de dados de dígitos do scikit-learn e dividi-lo em características e rótulos.

import numpy as np
from sklearn import datasets

X, y = datasets.load_digits(return_X_y=True)

Criar um modelo de Máquina de Vetores de Suporte (SVM)

Em seguida, criamos um modelo SVM com um kernel linear.

from sklearn import svm

svc = svm.SVC(kernel="linear")

Definir os valores dos hiperparâmetros a testar

Vamos testar diferentes valores do parâmetro de regularização C, que controla o equilíbrio entre maximizar a margem e minimizar o erro de classificação. Vamos testar 10 valores espaçados logaritmicamente entre 10^-10 e 1.

C_s = np.logspace(-10, 0, 10)

Executar validação cruzada e registrar resultados

Para cada valor de C, realizamos uma validação cruzada de 10 dobras e registramos a média e o desvio padrão das pontuações.

from sklearn.model_selection import cross_val_score

scores = list()
scores_std = list()
for C in C_s:
    svc.C = C
    this_scores = cross_val_score(svc, X, y, n_jobs=1)
    scores.append(np.mean(this_scores))
    scores_std.append(np.std(this_scores))

Plotar os resultados

Finalmente, plotamos as pontuações médias em função de C, e também incluímos barras de erro para visualizar o desvio padrão.

import matplotlib.pyplot as plt

plt.figure()
plt.semilogx(C_s, scores)
plt.semilogx(C_s, np.array(scores) + np.array(scores_std), "b--")
plt.semilogx(C_s, np.array(scores) - np.array(scores_std), "b--")
locs, labels = plt.yticks()
plt.yticks(locs, list(map(lambda x: "%g" % x, locs)))
plt.ylabel("Pontuação CV")
plt.xlabel("Parâmetro C")
plt.ylim(0, 1.1)
plt.show()

Resumo

Neste laboratório, realizámos validação cruzada de 10 dobras com um modelo SVM no conjunto de dados de dígitos, testando diferentes valores do parâmetro de regularização C. Plotamos os resultados para visualizar a relação entre C e a pontuação média de validação cruzada. Esta é uma técnica útil para ajustar hiperparâmetros e avaliar o desempenho do modelo.