Regressão Logística Multinomial para MNIST

Beginner

This tutorial is from open-source community. Access the source code

Introdução

Neste laboratório, aprenderemos a utilizar um algoritmo de regressão logística para classificar dígitos manuscritos do conjunto de dados MNIST. Utilizaremos o algoritmo SAGA para ajustar uma regressão logística multinomial com penalidade L1 numa subamostra da tarefa de classificação de dígitos MNIST.

Dicas da Máquina Virtual

Após o arranque da VM, clique no canto superior esquerdo para mudar para a aba Notebook para aceder ao Jupyter Notebook para a prática.

Por vezes, pode ser necessário esperar alguns segundos para o Jupyter Notebook terminar o carregamento. A validação das operações não pode ser automatizada devido a limitações no Jupyter Notebook.

Se tiver problemas durante o aprendizado, não hesite em contactar o Labby. Forneça feedback após a sessão e resolveremos o problema rapidamente para si.

Importar Bibliotecas

Começaremos importando as bibliotecas necessárias para este laboratório. Usaremos a biblioteca scikit-learn para obter o conjunto de dados, treinar o modelo e avaliar o desempenho do modelo.

import time
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state

Carregar o conjunto de dados MNIST

Carregaremos o conjunto de dados MNIST utilizando a função fetch_openml do scikit-learn. Também selecionaremos um subconjunto dos dados definindo o número de train_samples para 5000.

## Reduzir para uma convergência mais rápida
t0 = time.time()
train_samples = 5000

## Carregar dados de https://www.openml.org/d/554
X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)

Pré-processamento

Pré-processaremos os dados embaralhando-os, dividindo o conjunto de dados em conjuntos de treino e teste e escalando os dados usando StandardScaler.

random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
y = y[permutation]
X = X.reshape((X.shape[0], -1))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=train_samples, test_size=10000
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

Treinar o Modelo

Treinaremos o modelo usando regressão logística com penalidade L1 e algoritmo SAGA. Definiremos o valor de C como 50.0 dividido pelo número de amostras de treino.

## Aumentar a tolerância para uma convergência mais rápida
clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1)
clf.fit(X_train, y_train)

Avaliar o Modelo

Avaliaremos o desempenho do modelo calculando a esparcidade e a pontuação de precisão.

sparsity = np.mean(clf.coef_ == 0) * 100
score = clf.score(X_test, y_test)

print("Esparsidade com penalidade L1: %.2f%%" % sparsity)
print("Pontuação de teste com penalidade L1: %.4f" % score)

Visualizar o Modelo

Visualizaremos o modelo plotando os vetores de classificação para cada classe.

coef = clf.coef_.copy()
plt.figure(figsize=(10, 5))
scale = np.abs(coef).max()
for i in range(10):
    l1_plot = plt.subplot(2, 5, i + 1)
    l1_plot.imshow(
        coef[i].reshape(28, 28),
        interpolation="nearest",
        cmap=plt.cm.RdBu,
        vmin=-scale,
        vmax=scale,
    )
    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
    l1_plot.set_xlabel("Classe %i" % i)
plt.suptitle("Vetor de classificação para...")

run_time = time.time() - t0
print("Exemplo de execução em %.3f s" % run_time)
plt.show()

Resumo

Neste laboratório, aprendemos a utilizar a regressão logística para classificar dígitos manuscritos do conjunto de dados MNIST. Também aprendemos a utilizar o algoritmo SAGA com penalidade L1 para regressão logística. Alcançamos uma pontuação de precisão superior a 0,8 com um vetor de pesos esparso, tornando o modelo mais interpretável. No entanto, também observamos que esta precisão está significativamente abaixo do que pode ser alcançado por um modelo linear penalizado L2 ou um modelo perceptron multicamadas não linear neste conjunto de dados.