决策树剪枝

Beginner

This tutorial is from open-source community. Access the source code

简介

在机器学习中,决策树是常用的模型。然而,决策树有过拟合训练数据的倾向,这可能导致它们在测试数据上表现不佳。防止过拟合的一种方法是对决策树进行剪枝。成本复杂度剪枝是一种流行的决策树剪枝方法。在本实验中,我们将使用 scikit-learn 来演示决策树的成本复杂度剪枝。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作的验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。

加载数据

我们将使用 scikit-learn 中的乳腺癌数据集。这个数据集有 30 个特征和一个二元目标变量,用于指示患者患的是恶性还是良性癌症。

from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

划分数据

我们将把数据划分为训练集和测试集。

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

确定合适的α值

我们想要确定用于修剪决策树的合适α值。我们可以通过绘制叶子节点的总杂质与修剪后树的有效α值的关系图来做到这一点。

from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt

clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker="o", drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")

训练决策树

接下来,我们将使用每个有效α值训练一棵决策树。ccp_alphas中的最后一个值是修剪整个树的α值,这样树就只剩下一个节点。

clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
print(
    "Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
        clfs[-1].tree_.node_count, ccp_alphas[-1]
    )
)

移除平凡树

我们将从决策树列表中移除只有一个节点的平凡树。

clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]

绘制树的节点数量和深度

我们将绘制随着α值增加,树的节点数量和深度的变化情况。

node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1)
ax[0].plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker="o", drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()

确定最佳α值

我们想要确定用于修剪决策树的最佳α值。我们可以通过绘制训练集和测试集的准确率与α的关系图来做到这一点。

train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")
ax.legend()
plt.show()

总结

在本实验中,我们展示了如何使用 scikit-learn 对决策树执行成本复杂度剪枝。我们将数据拆分为训练集和测试集,确定用于剪枝的合适α值,使用有效的α值训练决策树,绘制树的节点数量和深度,并根据训练集和测试集的准确率确定用于剪枝的最佳α值。