Introdução
Este laboratório utiliza validação cruzada com uma máquina de vetores de suporte (SVM) no conjunto de dados de dígitos. Este é um problema de classificação, onde a tarefa é identificar dígitos a partir de imagens de dígitos manuscritos.
Dicas da Máquina Virtual
Após o arranque da máquina virtual, clique no canto superior esquerdo para mudar para a aba Notebook para aceder ao Jupyter Notebook para prática.
Por vezes, pode ser necessário esperar alguns segundos para o Jupyter Notebook terminar de carregar. A validação de operações não pode ser automatizada devido a limitações no Jupyter Notebook.
Se tiver problemas durante a aprendizagem, não hesite em contactar o Labby. Forneça feedback após a sessão e resolveremos prontamente o problema para si.
Carregar o conjunto de dados
Primeiro, precisamos carregar o conjunto de dados de dígitos do scikit-learn e dividi-lo em características e rótulos.
import numpy as np
from sklearn import datasets
X, y = datasets.load_digits(return_X_y=True)
Criar um modelo de Máquina de Vetores de Suporte (SVM)
Em seguida, criamos um modelo SVM com um kernel linear.
from sklearn import svm
svc = svm.SVC(kernel="linear")
Definir os valores dos hiperparâmetros a testar
Vamos testar diferentes valores do parâmetro de regularização C, que controla o equilíbrio entre maximizar a margem e minimizar o erro de classificação. Vamos testar 10 valores espaçados logaritmicamente entre 10^-10 e 1.
C_s = np.logspace(-10, 0, 10)
Executar validação cruzada e registrar resultados
Para cada valor de C, realizamos uma validação cruzada de 10 dobras e registramos a média e o desvio padrão das pontuações.
from sklearn.model_selection import cross_val_score
scores = list()
scores_std = list()
for C in C_s:
svc.C = C
this_scores = cross_val_score(svc, X, y, n_jobs=1)
scores.append(np.mean(this_scores))
scores_std.append(np.std(this_scores))
Plotar os resultados
Finalmente, plotamos as pontuações médias em função de C, e também incluímos barras de erro para visualizar o desvio padrão.
import matplotlib.pyplot as plt
plt.figure()
plt.semilogx(C_s, scores)
plt.semilogx(C_s, np.array(scores) + np.array(scores_std), "b--")
plt.semilogx(C_s, np.array(scores) - np.array(scores_std), "b--")
locs, labels = plt.yticks()
plt.yticks(locs, list(map(lambda x: "%g" % x, locs)))
plt.ylabel("Pontuação CV")
plt.xlabel("Parâmetro C")
plt.ylim(0, 1.1)
plt.show()
Resumo
Neste laboratório, realizámos validação cruzada de 10 dobras com um modelo SVM no conjunto de dados de dígitos, testando diferentes valores do parâmetro de regularização C. Plotamos os resultados para visualizar a relação entre C e a pontuação média de validação cruzada. Esta é uma técnica útil para ajustar hiperparâmetros e avaliar o desempenho do modelo.