计算得分和计算时间
我们将使用GridSearchCV
对象的cv_results_
属性来计算每个超参数组合的平均拟合时间和得分时间。然后,我们将使用plotly.express.scatter
和plotly.express.line
绘制结果,以可视化经过的计算时间与平均测试得分之间的权衡。
import plotly.express as px
import plotly.colors as colors
from plotly.subplots import make_subplots
fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=True,
subplot_titles=["Train time vs score", "Predict time vs score"],
)
model_names = [result["model"] for result in results]
colors_list = colors.qualitative.Plotly * (
len(model_names) // len(colors.qualitative.Plotly) + 1
)
for idx, result in enumerate(results):
cv_results = result["cv_results"].round(3)
model_name = result["model"]
param_name = list(param_grids[model_name].keys())[0]
cv_results[param_name] = cv_results["param_" + param_name]
cv_results["model"] = model_name
scatter_fig = px.scatter(
cv_results,
x="mean_fit_time",
y="mean_test_score",
error_x="std_fit_time",
error_y="std_test_score",
hover_data=param_name,
color="model",
)
line_fig = px.line(
cv_results,
x="mean_fit_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=1)
fig.add_trace(line_trace, row=1, col=1)
scatter_fig = px.scatter(
cv_results,
x="mean_score_time",
y="mean_test_score",
error_x="std_score_time",
error_y="std_test_score",
hover_data=param_name,
)
line_fig = px.line(
cv_results,
x="mean_score_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=2)
fig.add_trace(line_trace, row=1, col=2)
fig.update_layout(
xaxis=dict(title="Train time (s) - lower is better"),
yaxis=dict(title="Test R2 score - higher is better"),
xaxis2=dict(title="Predict time (s) - lower is better"),
legend=dict(x=0.72, y=0.05, traceorder="normal", borderwidth=1),
title=dict(x=0.5, text="Speed-score trade-off of tree-based ensembles"),
)