Validación cruzada en el conjunto de datos de dígitos

Beginner

This tutorial is from open-source community. Access the source code

Introducción

Esta práctica utiliza validación cruzada con una máquina de vectores de soporte (SVM) en el conjunto de datos de dígitos. Este es un problema de clasificación, donde la tarea es identificar dígitos a partir de imágenes de dígitos manuscritos.

Consejos sobre la VM

Una vez finalizada la inicialización de la VM, haga clic en la esquina superior izquierda para cambiar a la pestaña Cuaderno y acceder a Jupyter Notebook para practicar.

En ocasiones, es posible que tenga que esperar unos segundos a que Jupyter Notebook termine de cargarse. La validación de operaciones no se puede automatizar debido a las limitaciones de Jupyter Notebook.

Si tiene problemas durante el aprendizaje, no dude en preguntar a Labby. Deje su retroalimentación después de la sesión y lo resolveremos rápidamente para usted.

Cargar el conjunto de datos

En primer lugar, necesitamos cargar el conjunto de datos de dígitos de scikit-learn y dividirlo en características y etiquetas.

import numpy as np
from sklearn import datasets

X, y = datasets.load_digits(return_X_y=True)

Crear un modelo de Máquina de Vectores de Soporte (SVM)

A continuación, creamos un modelo de SVM con un kernel lineal.

from sklearn import svm

svc = svm.SVC(kernel="linear")

Definir los valores de hiperparámetros para probar

Probaremos diferentes valores del parámetro de regularización C, que controla el equilibrio entre maximizar el margen y minimizar el error de clasificación. Probaremos 10 valores espaciados logarítmicamente entre 10^-10 y 1.

C_s = np.logspace(-10, 0, 10)

Realizar validación cruzada y registrar resultados

Para cada valor de C, realizamos una validación cruzada de 10 pliegues y registramos la media y la desviación estándar de las puntuaciones.

from sklearn.model_selection import cross_val_score

scores = list()
scores_std = list()
for C in C_s:
    svc.C = C
    this_scores = cross_val_score(svc, X, y, n_jobs=1)
    scores.append(np.mean(this_scores))
    scores_std.append(np.std(this_scores))

Graficar los resultados

Finalmente, graficamos las puntuaciones promedio en función de C, y también incluimos barras de error para visualizar la desviación estándar.

import matplotlib.pyplot as plt

plt.figure()
plt.semilogx(C_s, scores)
plt.semilogx(C_s, np.array(scores) + np.array(scores_std), "b--")
plt.semilogx(C_s, np.array(scores) - np.array(scores_std), "b--")
locs, labels = plt.yticks()
plt.yticks(locs, list(map(lambda x: "%g" % x, locs)))
plt.ylabel("Puntuación CV")
plt.xlabel("Parámetro C")
plt.ylim(0, 1.1)
plt.show()

Resumen

En este laboratorio, realizamos una validación cruzada de 10 pliegues con un modelo de SVM en el conjunto de datos de dígitos, probando diferentes valores del parámetro de regularización C. Graficamos los resultados para visualizar la relación entre C y la puntuación promedio de validación cruzada. Esta es una técnica útil para ajustar hiperparámetros y evaluar el rendimiento del modelo.