使用 Scikit-learn 进行决策树分类

Beginner

This tutorial is from open-source community. Access the source code

简介

在本实验中,我们将学习如何使用 scikit-learn 中的决策树进行分类。决策树是一种用于分类和回归的非参数监督学习方法。它们易于理解和解释,并且可以处理数值型和类别型数据。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。

导入必要的库

首先,我们需要导入必要的库。我们将使用 scikit-learn 来构建和训练决策树分类器。

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

加载数据集

接下来,我们将加载鸢尾花数据集。该数据集包含了三种不同鸢尾花的四个特征的信息。我们将使用这个数据集来训练我们的决策树分类器。

## 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

划分数据集

在训练决策树分类器之前,我们需要将数据集划分为训练集和测试集。我们将使用 70% 的数据进行训练,30% 的数据进行测试。

## 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

创建并训练决策树分类器

现在,我们可以使用训练数据创建并训练决策树分类器。

## 创建一个决策树分类器
clf = tree.DecisionTreeClassifier()

## 训练分类器
clf.fit(X_train, y_train)

进行预测

一旦分类器训练完成,我们就可以用它对测试数据进行预测。

## 对测试数据进行预测
y_pred = clf.predict(X_test)

## 打印预测值
print("Predicted values:", y_pred)

评估模型

最后,我们可以通过将预测值与真实值进行比较来评估模型的准确性。

## 计算模型的准确性
accuracy = accuracy_score(y_test, y_pred)

## 打印准确性
print("Accuracy:", accuracy)

总结

在本实验中,我们学习了如何使用 scikit-learn 中的决策树进行分类。我们加载了鸢尾花数据集,将数据拆分为训练集和测试集,创建并训练了决策树分类器,对测试数据进行了预测,并评估了模型的准确性。决策树是用于分类任务的一种强大且可解释的方法。