绘制结果
我们将绘制两个模型（SAMME 和 SAMME.R）的测试误差和每棵树的分类误差，以及 SAMME 模型中每棵树的提升权重。
# Plot three side-by-side panels: the test-error curves of both models, the
# per-tree estimator errors of both models, and the per-tree boosting weights
# of the SAMME model.
n_trees_discrete = len(bdt_discrete)
n_trees_real = len(bdt_real)

# Boosting might terminate early, but the following arrays are always
# n_estimators long. We crop them to the actual number of trees here:
discrete_estimator_errors = bdt_discrete.estimator_errors_[:n_trees_discrete]
real_estimator_errors = bdt_real.estimator_errors_[:n_trees_real]
discrete_estimator_weights = bdt_discrete.estimator_weights_[:n_trees_discrete]

# 1-based tree indices used as the x-axis of every curve of a given model.
discrete_rounds = range(1, n_trees_discrete + 1)
real_rounds = range(1, n_trees_real + 1)

fig, (ax_test, ax_error, ax_weight) = plt.subplots(1, 3, figsize=(15, 5))

# Left panel: test error of each model.
# NOTE(review): discrete_test_errors / real_test_errors are computed outside
# this block; presumably one value per fitted tree -- confirm against caller.
ax_test.plot(discrete_rounds, discrete_test_errors, c="black", label="SAMME")
ax_test.plot(
    real_rounds,
    real_test_errors,
    c="black",
    linestyle="dashed",
    label="SAMME.R",
)
ax_test.legend()
ax_test.set_ylim(0.18, 0.62)
ax_test.set_ylabel("测试误差")
ax_test.set_xlabel("树的数量")

# Middle panel: estimator error of each individual tree, for both models.
ax_error.plot(
    discrete_rounds,
    discrete_estimator_errors,
    "b",
    label="SAMME",
    alpha=0.5,
)
ax_error.plot(real_rounds, real_estimator_errors, "r", label="SAMME.R", alpha=0.5)
ax_error.legend()
ax_error.set_ylabel("误差")
ax_error.set_xlabel("树的数量")
ax_error.set_ylim(
    (0.2, max(real_estimator_errors.max(), discrete_estimator_errors.max()) * 1.2)
)
# Both models are drawn on this axis, so the x-range must cover whichever
# ensemble kept more trees; otherwise the longer curve would be clipped.
ax_error.set_xlim((-20, max(n_trees_discrete, n_trees_real) + 20))

# Right panel: boosting weight of each tree (only the SAMME model is drawn).
ax_weight.plot(discrete_rounds, discrete_estimator_weights, "b", label="SAMME")
ax_weight.legend()
ax_weight.set_ylabel("权重")
ax_weight.set_xlabel("树的数量")
ax_weight.set_ylim((0, discrete_estimator_weights.max() * 1.2))
ax_weight.set_xlim((-20, n_trees_discrete + 20))

# Prevent overlapping y-axis labels
fig.subplots_adjust(wspace=0.25)
plt.show()