Compute Scores and Computation Times
For each combination of hyperparameters, we read the mean fit time, the mean score (prediction) time and the mean test score from the `cv_results_` attribute of the `GridSearchCV` object. We then plot these results with `plotly.express.scatter` and `plotly.express.line` to visualize the trade-off between elapsed computing time and mean test score.
import plotly.express as px
import plotly.colors as colors
from plotly.subplots import make_subplots
# One row of two panels that share the score (y) axis: train time on the
# left, predict time on the right.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=["Train time vs score", "Predict time vs score"],
    shared_yaxes=True,
)
model_names = [entry["model"] for entry in results]
# Tile the qualitative palette enough times that indexing it by model
# position never runs past the end; colors only repeat once there are more
# models than palette entries.
colors_list = (
    len(model_names) // len(colors.qualitative.Plotly) + 1
) * colors.qualitative.Plotly
def _add_speed_score_traces(
    fig, cv_results, time_key, col, trace_color, hover_param, legend_by=None
):
    """Add one model's scatter-with-error-bars and connecting line to a panel.

    Plots ``mean_<time_key>`` (x, with ``std_<time_key>`` error bars) against
    ``mean_test_score`` (y, with ``std_test_score`` error bars) in subplot
    ``(row=1, col=col)`` of ``fig``. Both traces are recolored with
    ``trace_color`` so a model keeps the same color in every panel.

    Parameters
    ----------
    fig : figure built by ``make_subplots``; the two traces are added to it
        in place (scatter first, then line).
    cv_results : table holding the ``mean_*``/``std_*`` columns named above
        and the ``hover_param`` column.
    time_key : column suffix, ``"fit_time"`` or ``"score_time"``.
    col : 1-based subplot column receiving the traces.
    trace_color : color applied to the scatter markers and to the line.
    hover_param : column shown in the scatter's hover tooltip.
    legend_by : column forwarded to ``px.scatter(color=...)``. Pass it for a
        single panel only, so each model is named once in the legend;
        ``None`` is plotly express's default (no color grouping).
    """
    x_col = f"mean_{time_key}"
    scatter_trace = px.scatter(
        cv_results,
        x=x_col,
        y="mean_test_score",
        error_x=f"std_{time_key}",
        error_y="std_test_score",
        hover_data=hover_param,
        color=legend_by,
    )["data"][0]
    # The line joins the points in the row order of cv_results.
    line_trace = px.line(cv_results, x=x_col, y="mean_test_score")["data"][0]
    # Override plotly express's default colors with the per-model color.
    scatter_trace.update(marker=dict(color=trace_color))
    line_trace.update(line=dict(color=trace_color))
    fig.add_trace(scatter_trace, row=1, col=col)
    fig.add_trace(line_trace, row=1, col=col)


for idx, result in enumerate(results):
    # Round for readable hover values; presumably result["cv_results"] is a
    # pandas DataFrame, whose .round returns a copy. TODO confirm.
    cv_results = result["cv_results"].round(3)
    model_name = result["model"]
    # Only the first hyperparameter of the model's grid is shown on hover.
    # NOTE(review): assumes each grid sweeps a single parameter; confirm
    # against param_grids.
    param_name = list(param_grids[model_name])[0]
    # Expose the swept parameter under its bare name instead of the
    # "param_"-prefixed column, so the hover label reads cleanly.
    cv_results[param_name] = cv_results["param_" + param_name]
    cv_results["model"] = model_name
    trace_color = colors_list[idx]

    # Left panel: train (fit) time. This scatter carries the legend entry.
    _add_speed_score_traces(
        fig,
        cv_results,
        "fit_time",
        1,
        trace_color,
        param_name,
        legend_by="model",
    )
    # Right panel: predict (score) time.
    _add_speed_score_traces(
        fig, cv_results, "score_time", 2, trace_color, param_name
    )
# Centered figure title, axis titles for both panels (the y axis is shared,
# so only the first one is titled), and legend position/border.
fig.update_layout(
    title_text="Speed-score trade-off of tree-based ensembles",
    title_x=0.5,
    xaxis_title="Train time (s) - lower is better",
    xaxis2_title="Predict time (s) - lower is better",
    yaxis_title="Test R2 score - higher is better",
    legend_x=0.72,
    legend_y=0.05,
    legend_traceorder="normal",
    legend_borderwidth=1,
)