Introduction
La validation croisée est une technique pour évaluer les modèles d'apprentissage automatique en entraînant plusieurs modèles sur différents sous-ensembles des données disponibles et en les évaluant sur le sous-ensemble complémentaire. Cela aide à éviter le surapprentissage et assure que le modèle est généralisable. Dans ce tutoriel, nous utiliserons scikit-learn pour explorer différentes techniques de validation croisée et leur comportement.
Conseils sur la machine virtuelle
Une fois le démarrage de la machine virtuelle terminé, cliquez dans le coin supérieur gauche pour basculer vers l'onglet Carnet de notes pour accéder au carnet Jupyter pour la pratique.
Parfois, vous devrez peut-être attendre quelques secondes pour que le carnet Jupyter ait fini de charger. La validation des opérations ne peut pas être automatisée en raison des limitations du carnet Jupyter.
Si vous rencontrez des problèmes pendant l'apprentissage, n'hésitez pas à demander à Labby. Donnez votre feedback après la session, et nous réglerons rapidement le problème pour vous.
Visualiser les données
Dans cette étape, nous allons visualiser les données avec lesquelles nous allons travailler. Les données sont composées de 100 points de données d'entrée générés aléatoirement, de 3 classes réparties de manière inégale parmi les points de données et de 10 "groupes" répartis régulièrement parmi les points de données.
## Générer les données de classe/groupe
n_points = 100
X = rng.randn(100, 10)
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(100 * perc) for ii, perc in enumerate(percentiles_classes)])
## Générer des groupes inégaux
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
def visualize_groups(classes, groups, name):
## Visualiser les groupes du jeu de données
fig, ax = plt.subplots()
ax.scatter(
range(len(groups)),
[0.5] * len(groups),
c=groups,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.scatter(
range(len(groups)),
[3.5] * len(groups),
c=classes,
marker="_",
lw=50,
cmap=cmap_data,
)
ax.set(
ylim=[-1, 5],
yticks=[0.5, 3.5],
yticklabels=["Données\ngroupe", "Données\nclasse"],
xlabel="Index d'échantillonnage"
)
visualize_groups(y, groups, "sans groupes")
Visualiser les indices de validation croisée
Dans cette étape, nous allons définir une fonction pour visualiser le comportement de chaque objet de validation croisée. Nous effectuerons 4 partitions des données. Pour chaque partition, nous visualiserons les indices choisis pour l'ensemble d'entraînement (en bleu) et l'ensemble de test (en rouge).
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
"""Créer un graphique d'échantillonnage pour les indices d'un objet de validation croisée."""
## Générer les visualisations d'entraînement/tests pour chaque partition CV
for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
## Remplir les indices avec les groupes d'entraînement/tests
indices = np.array([np.nan] * len(X))
indices[tt] = 1
indices[tr] = 0
## Visualiser les résultats
ax.scatter(
range(len(indices)),
[ii + 0.5] * len(indices),
c=indices,
marker="_",
lw=lw,
cmap=cmap_cv,
vmin=-0.2,
vmax=1.2,
)
## Tracer les classes et les groupes de données à la fin
ax.scatter(
range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
)
ax.scatter(
range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
)
## Mise en forme
yticklabels = list(range(n_splits)) + ["classe", "groupe"]
ax.set(
yticks=np.arange(n_splits + 2) + 0.5,
yticklabels=yticklabels,
xlabel="Index d'échantillonnage",
ylabel="Itération de validation croisée",
ylim=[n_splits + 2.2, -0.2],
xlim=[0, 100],
)
ax.set_title("{}".format(type(cv).__name__), fontsize=15)
return ax
Comparer les objets de validation croisée
Dans cette étape, nous allons comparer le comportement de la validation croisée pour différents objets de validation croisée de scikit-learn. Nous allons parcourir plusieurs objets de validation croisée courants, en visualisant le comportement de chacun. Notez comment certains utilisent les informations de groupe/classe tandis que d'autres ne le font pas.
cvs = [
KFold,
GroupKFold,
ShuffleSplit,
StratifiedKFold,
StratifiedGroupKFold,
GroupShuffleSplit,
StratifiedShuffleSplit,
TimeSeriesSplit,
]
for cv in cvs:
this_cv = cv(n_splits=n_splits)
fig, ax = plt.subplots(figsize=(6, 3))
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
ax.legend(
[Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
["Ensemble de test", "Ensemble d'entraînement"],
loc=(1.02, 0.8),
)
## Ajuster la légende pour qu'elle rentre
plt.tight_layout()
fig.subplots_adjust(right=0.7)
plt.show()
Sommaire
Dans ce tutoriel, nous avons utilisé scikit-learn pour explorer différentes techniques de validation croisée et leur comportement. Nous avons visualisé les données, défini une fonction pour visualiser les indices de validation croisée et comparé le comportement de la validation croisée pour différents objets de validation croisée de scikit-learn.