Calcular puntuaciones y tiempos de cálculo
Calcularemos los tiempos de ajuste y puntuación promedio para cada combinación de hiperparámetros utilizando el atributo cv_results_
del objeto GridSearchCV
. Luego graficaremos los resultados utilizando plotly.express.scatter
y plotly.express.line
para visualizar el equilibrio entre el tiempo de cálculo transcurrido y la puntuación promedio de prueba.
import plotly.express as px
import plotly.colors as colors
from plotly.subplots import make_subplots
fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=True,
subplot_titles=["Tiempo de entrenamiento vs puntuación", "Tiempo de predicción vs puntuación"],
)
model_names = [result["model"] for result in results]
colors_list = colors.qualitative.Plotly * (
len(model_names) // len(colors.qualitative.Plotly) + 1
)
for idx, result in enumerate(results):
cv_results = result["cv_results"].round(3)
model_name = result["model"]
param_name = list(param_grids[model_name].keys())[0]
cv_results[param_name] = cv_results["param_" + param_name]
cv_results["model"] = model_name
scatter_fig = px.scatter(
cv_results,
x="mean_fit_time",
y="mean_test_score",
error_x="std_fit_time",
error_y="std_test_score",
hover_data=param_name,
color="model",
)
line_fig = px.line(
cv_results,
x="mean_fit_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=1)
fig.add_trace(line_trace, row=1, col=1)
scatter_fig = px.scatter(
cv_results,
x="mean_score_time",
y="mean_test_score",
error_x="std_score_time",
error_y="std_test_score",
hover_data=param_name,
)
line_fig = px.line(
cv_results,
x="mean_score_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=2)
fig.add_trace(line_trace, row=1, col=2)
fig.update_layout(
xaxis=dict(title="Tiempo de entrenamiento (s) - menor es mejor"),
yaxis=dict(title="Puntuación R2 de prueba - mayor es mejor"),
xaxis2=dict(title="Tiempo de predicción (s) - menor es mejor"),
legend=dict(x=0.72, y=0.05, traceorder="normal", borderwidth=1),
title=dict(x=0.5, text="Equilibrio velocidad-puntuación de los conjuntos de árboles"),
)