导入所需库并加载数据集
在这一步中,你将学习如何导入所需的库并加载鸢尾花数据集。按照以下步骤完成此步骤:
在iris_classification_svm.py
中,导入所需的库,包括用于加载数据集、拆分数据、创建支持向量机(SVM)模型以及评估其性能的库。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, classification_report
从sklearn.datasets
加载鸢尾花数据,并将数据集拆分为训练集和测试集。数据集按80-20的比例进行训练和测试拆分,随机种子设为42以确保可重复性。
## 在同一个文件中继续
def load_and_split_data() -> tuple:
"""
返回值:
元组:[X_train, X_test, y_train, y_test]
"""
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
return X_train, X_test, y_train, y_test
这段代码加载鸢尾花数据集并将其拆分为训练集和测试集,用于机器学习。以下是各部分的详细说明:
- 导入必要的库:
sklearn.datasets
用于加载数据集,包括鸢尾花数据集。
sklearn.model_selection
提供将数据集拆分为训练集和测试集的实用工具。
sklearn.svm
包含支持向量机(SVM)的类,这是一种机器学习算法。
sklearn.metrics
包括用于评估模型性能的工具,如准确率和分类报告。
- 函数定义:定义了一个名为
load_and_split_data
的函数。此函数执行以下任务:
- 加载鸢尾花数据集:
load_iris()
是sklearn.datasets
提供的一个函数,用于加载鸢尾花数据集,这是一个用于分类任务的流行数据集。它包含来自三个不同物种的150朵鸢尾花的测量数据。
- 数据分离:数据集被分离为特征(
X
)和目标标签(y
)。在这种情况下,X
将是鸢尾花的4维测量数据,而y
将是相应的物种标签(0、1或2)。
- 拆分数据:使用
sklearn.model_selection
中的train_test_split
将数据拆分为训练子集和测试子集。test_size = 0.2
参数表示20%的数据将用于测试,其余80%将用于训练。random_state = 42
确保拆分的可重复性;每次运行代码时使用相同的种子(这里是42)将产生相同的拆分结果。
- 返回值:该函数返回一个包含
X_train
、X_test
、y_train
和y_test
的元组,它们分别是训练数据和测试数据的特征集和目标集。