점수 및 계산 시간 계산
GridSearchCV 객체의 cv_results_ 속성을 사용하여 각 하이퍼파라미터 조합에 대한 평균 적합 및 점수 시간을 계산합니다. 그런 다음 plotly.express.scatter 및 plotly.express.line을 사용하여 결과를 시각화하여 계산 시간과 평균 테스트 점수 간의 트레이드오프를 시각화합니다.
import plotly.express as px
import plotly.colors as colors
from plotly.subplots import make_subplots
fig = make_subplots(
rows=1,
cols=2,
shared_yaxes=True,
subplot_titles=["Train time vs score", "Predict time vs score"],
)
model_names = [result["model"] for result in results]
colors_list = colors.qualitative.Plotly * (
len(model_names) // len(colors.qualitative.Plotly) + 1
)
for idx, result in enumerate(results):
cv_results = result["cv_results"].round(3)
model_name = result["model"]
param_name = list(param_grids[model_name].keys())[0]
cv_results[param_name] = cv_results["param_" + param_name]
cv_results["model"] = model_name
scatter_fig = px.scatter(
cv_results,
x="mean_fit_time",
y="mean_test_score",
error_x="std_fit_time",
error_y="std_test_score",
hover_data=param_name,
color="model",
)
line_fig = px.line(
cv_results,
x="mean_fit_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=1)
fig.add_trace(line_trace, row=1, col=1)
scatter_fig = px.scatter(
cv_results,
x="mean_score_time",
y="mean_test_score",
error_x="std_score_time",
error_y="std_test_score",
hover_data=param_name,
)
line_fig = px.line(
cv_results,
x="mean_score_time",
y="mean_test_score",
)
scatter_trace = scatter_fig["data"][0]
line_trace = line_fig["data"][0]
scatter_trace.update(marker=dict(color=colors_list[idx]))
line_trace.update(line=dict(color=colors_list[idx]))
fig.add_trace(scatter_trace, row=1, col=2)
fig.add_trace(line_trace, row=1, col=2)
fig.update_layout(
xaxis=dict(title="Train time (s) - lower is better"),
yaxis=dict(title="Test R2 score - higher is better"),
xaxis2=dict(title="Predict time (s) - lower is better"),
legend=dict(x=0.72, y=0.05, traceorder="normal", borderwidth=1),
title=dict(x=0.5, text="Speed-score trade-off of tree-based ensembles"),
)