Regresión logística multinomial MNIST

Beginner

This tutorial is from open-source community. Access the source code

Introducción

En este laboratorio, aprenderemos a usar un algoritmo de regresión logística para clasificar dígitos manuscritos del conjunto de datos MNIST. Usaremos el algoritmo SAGA para ajustar una regresión logística multinomial con penalización L1 en un subconjunto de la tarea de clasificación de dígitos MNIST.

Consejos sobre la VM

Una vez finalizada la inicialización de la VM, haga clic en la esquina superior izquierda para cambiar a la pestaña Cuaderno y acceder a Jupyter Notebook para practicar.

A veces, es posible que tenga que esperar unos segundos a que Jupyter Notebook termine de cargarse. La validación de las operaciones no se puede automatizar debido a las limitaciones de Jupyter Notebook.

Si tiene problemas durante el aprendizaje, no dude en preguntar a Labby. Deje sus comentarios después de la sesión y lo resolveremos rápidamente para usted.

Importar bibliotecas

Comenzaremos importando las bibliotecas necesarias para este laboratorio. Usaremos la biblioteca scikit-learn para obtener el conjunto de datos, entrenar el modelo y evaluar el rendimiento del modelo.

import time
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_random_state

Cargar el conjunto de datos MNIST

Cargaremos el conjunto de datos MNIST usando la función fetch_openml de scikit-learn. También seleccionaremos un subconjunto de los datos fijando el número de train_samples en 5000.

## Turn down for faster convergence
t0 = time.time()
train_samples = 5000

## Load data from https://www.openml.org/d/554
X, y = fetch_openml(
    "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
)

Preprocesamiento

Preprocesaremos los datos mezclando los datos, dividiendo el conjunto de datos en conjuntos de entrenamiento y prueba y escalando los datos usando StandardScaler.

random_state = check_random_state(0)
permutación = random_state.permutación(X.shape[0])
X = X[permutación]
y = y[permutación]
X = X.reshape((X.shape[0], -1))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=train_samples, test_size=10000
)

escalador = StandardScaler()
X_train = escalador.fit_transform(X_train)
X_test = escalador.transform(X_test)

Entrenar el modelo

Entrenaremos el modelo usando regresión logística con penalización L1 y el algoritmo SAGA. Estableceremos el valor de C en 50.0 dividido por el número de muestras de entrenamiento.

## Turn up tolerance for faster convergence
clf = LogisticRegression(C=50.0 / train_samples, penalty="l1", solver="saga", tol=0.1)
clf.fit(X_train, y_train)

Evaluar el modelo

Evaluaremos el rendimiento del modelo calculando la esparsidad y la puntuación de precisión.

esparsidad = np.mean(clf.coef_ == 0) * 100
puntuación = clf.score(X_test, y_test)

print("Esparsidad con penalización L1: %.2f%%" % esparsidad)
print("Puntuación de prueba con penalización L1: %.4f" % puntuación)

Visualizar el modelo

Visualizaremos el modelo trazando los vectores de clasificación para cada clase.

coef = clf.coef_.copy()
plt.figure(figsize=(10, 5))
scale = np.abs(coef).max()
for i in range(10):
    l1_plot = plt.subplot(2, 5, i + 1)
    l1_plot.imshow(
        coef[i].reshape(28, 28),
        interpolation="nearest",
        cmap=plt.cm.RdBu,
        vmin=-scale,
        vmax=scale,
    )
    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
    l1_plot.set_xlabel("Clase %i" % i)
plt.suptitle("Vector de clasificación para...")

run_time = time.time() - t0
print("Ejemplo ejecutado en %.3f s" % run_time)
plt.show()

Resumen

En este laboratorio, aprendimos cómo usar la regresión logística para clasificar dígitos manuscritos del conjunto de datos MNIST. También aprendimos cómo usar el algoritmo SAGA con penalización L1 para la regresión logística. Logramos una puntuación de precisión superior a 0,8 con un vector de pesos esparso, lo que hace que el modelo sea más interpretable. Sin embargo, también notamos que esta precisión está significativamente por debajo de lo que puede alcanzarse por un modelo lineal penalizado con L2 o un modelo de perceptrón multicapa no lineal en este conjunto de datos.