Calculer les scores et les temps de calcul
Nous allons calculer les temps moyens d'ajustement et de scoring pour chaque combinaison d'hyperparamètres en utilisant l'attribut cv_results_
de l'objet GridSearchCV
. Nous tracerons ensuite les résultats à l'aide de plotly.express.scatter
et plotly.express.line
pour visualiser le compromis entre le temps de calcul écoulé et la moyenne du score de test.
import plotly.express as px
import plotly.colors as colors
from plotly.subplots import make_subplots
fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=True,
subplot_titles=["Temps d'entraînement vs score", "Temps de prédiction vs score"],
)
model_names = [result["model"] for result in results]
colors_list = colors.qualitative.Plotly * (
len(model_names) // len(colors.qualitative.Plotly) + 1
)
for idx, result in enumerate(results):
cv_results = result["cv_results"].round(3)
model_name = result["model"]
param_name = list(param_grids[model_name].keys())[0]
cv_results[param_name] = cv_results["param_" + param_name]
cv_results["model"] = model_name
scatter_fig = px.scatter(
cv_results,
x="mean_fit_time",
y="mean_test_score",
error_x="std_fit_time",
error_y="std_test_score",
hover_data=param_name,
color="model",
)
line_fig = px.line(
cv_results,
x="mean_fit_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=1)
fig.add_trace(line_trace, row=1, col=1)
scatter_fig = px.scatter(
cv_results,
x="mean_score_time",
y="mean_test_score",
error_x="std_score_time",
error_y="std_test_score",
hover_data=param_name,
)
line_fig = px.line(
cv_results,
x="mean_score_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=2)
fig.add_trace(line_trace, row=1, col=2)
fig.update_layout(
xaxis=dict(title="Temps d'entraînement (s) - moins élevé est meilleur"),
yaxis=dict(title="Score R2 de test - plus élevé est meilleur"),
xaxis2=dict(title="Temps de prédiction (s) - moins élevé est meilleur"),
legend=dict(x=0.72, y=0.05, traceorder="normal", borderwidth=1),
title=dict(x=0.5, text="Compromis vitesse - score des ensembles basés sur des arbres"),
)