Introdução
Neste laboratório, aprenderemos a utilizar um algoritmo de regressão logística para classificar dígitos manuscritos do conjunto de dados MNIST. Utilizaremos o algoritmo SAGA para ajustar uma regressão logística multinomial com penalidade L1 numa subamostra da tarefa de classificação de dígitos MNIST.
Dicas da Máquina Virtual
Após o arranque da VM, clique no canto superior esquerdo para mudar para a aba Notebook para aceder ao Jupyter Notebook para a prática.
Por vezes, pode ser necessário esperar alguns segundos para o Jupyter Notebook terminar o carregamento. A validação das operações não pode ser automatizada devido a limitações no Jupyter Notebook.
Se tiver problemas durante o aprendizado, não hesite em contactar o Labby. Forneça feedback após a sessão e resolveremos o problema rapidamente para si.
Importar Bibliotecas
Começaremos importando as bibliotecas necessárias para este laboratório. Usaremos a biblioteca scikit-learn para obter o conjunto de dados, treinar o modelo e avaliar o desempenho do modelo.
import time
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state
Carregar o conjunto de dados MNIST
Carregaremos o conjunto de dados MNIST utilizando a função fetch_openml do scikit-learn. Também selecionaremos um subconjunto dos dados definindo o número de train_samples para 5000.
## Reduzir para uma convergência mais rápida
t0 = time.time()
train_samples = 5000
## Carregar dados de https://www.openml.org/d/554
X, y = fetch_openml(
"mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)
Pré-processamento
Pré-processaremos os dados embaralhando-os, dividindo o conjunto de dados em conjuntos de treino e teste e escalando os dados usando StandardScaler.
random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))
X_train, X_test, y_train, y_test = train_test_split(
X, y, train_size=train_samples, test_size=10000
)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
Treinar o Modelo
Treinaremos o modelo usando regressão logística com penalidade L1 e algoritmo SAGA. Definiremos o valor de C como 50.0 dividido pelo número de amostras de treino.
## Aumentar a tolerância para uma convergência mais rápida
clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1)
clf.fit(X_train, y_train)
Avaliar o Modelo
Avaliaremos o desempenho do modelo calculando a esparcidade e a pontuação de precisão.
sparsity = np.mean(clf.coef_ == 0) * 100
score = clf.score(X_test, y_test)
print("Esparsidade com penalidade L1: %.2f%%" % sparsity)
print("Pontuação de teste com penalidade L1: %.4f" % score)
Visualizar o Modelo
Visualizaremos o modelo plotando os vetores de classificação para cada classe.
coef = clf.coef_.copy()
plt.figure(figsize=(10, 5))
scale = np.abs(coef).max()
for i in range(10):
l1_plot = plt.subplot(2, 5, i + 1)
l1_plot.imshow(
coef[i].reshape(28, 28),
interpolation="nearest",
cmap=plt.cm.RdBu,
vmin=-scale,
vmax=scale,
)
l1_plot.set_xticks(())
l1_plot.set_yticks(())
l1_plot.set_xlabel("Classe %i" % i)
plt.suptitle("Vetor de classificação para...")
run_time = time.time() - t0
print("Exemplo de execução em %.3f s" % run_time)
plt.show()
Resumo
Neste laboratório, aprendemos a utilizar a regressão logística para classificar dígitos manuscritos do conjunto de dados MNIST. Também aprendemos a utilizar o algoritmo SAGA com penalidade L1 para regressão logística. Alcançamos uma pontuação de precisão superior a 0,8 com um vetor de pesos esparso, tornando o modelo mais interpretável. No entanto, também observamos que esta precisão está significativamente abaixo do que pode ser alcançado por um modelo linear penalizado L2 ou um modelo perceptron multicamadas não linear neste conjunto de dados.