线性回归拟合与绘图

Machine LearningMachine LearningBeginner
立即练习

💡 本教程由 AI 辅助翻译自英文原版。如需查看原文,您可以 切换至英文原版

简介

在这个项目中,你将学习如何对一组数据点进行线性回归,并使用Matplotlib可视化结果。线性回归是一种基本的机器学习技术,用于对因变量(y)和一个或多个自变量(x)之间的关系进行建模。

🎯 任务

在这个项目中,你将学习:

  • 如何将给定数据转换为Numpy数组以便于操作
  • 如何计算线性回归模型的系数,包括斜率(w)和截距(b)
  • 如何在散点图上绘制数据点,并在同一图上绘制线性回归直线

🏆 成果

完成这个项目后,你将能够:

  • 为线性回归分析准备数据
  • 使用Numpy函数计算线性回归参数
  • 使用Matplotlib创建散点图并叠加线性回归直线
  • 更好地理解线性回归及其在数据分析和可视化中的实际应用

Skills Graph

%%%%{init: {'theme':'neutral'}}%%%% flowchart RL ml(("Machine Learning")) -.-> ml/BasicConceptsGroup(["Basic Concepts"]) ml(("Machine Learning")) -.-> ml/RegressionAlgorithmsGroup(["Regression Algorithms"]) ml(("Machine Learning")) -.-> ml/EvaluationMetricsGroup(["Evaluation Metrics"]) ml(("Machine Learning")) -.-> ml/FrameworkandSoftwareGroup(["Framework and Software"]) ml/BasicConceptsGroup -.-> ml/basic_concept("Basic Concept") ml/RegressionAlgorithmsGroup -.-> ml/linear_regression("Linear Regression") ml/EvaluationMetricsGroup -.-> ml/mse("Mean Squared Error") ml/FrameworkandSoftwareGroup -.-> ml/sklearn("scikit-learn") subgraph Lab Skills ml/basic_concept -.-> lab-300236{{"线性回归拟合与绘图"}} ml/linear_regression -.-> lab-300236{{"线性回归拟合与绘图"}} ml/mse -.-> lab-300236{{"线性回归拟合与绘图"}} ml/sklearn -.-> lab-300236{{"线性回归拟合与绘图"}} end

将数据转换为NumPy数组

在这一步中,你将学习如何将给定数据转换为NumPy数组以便于操作。

打开 linear_regression_plot.py 文件。

linear_plot() 函数开头找到 data 变量。

使用 np.array() 函数将 data 列表转换为NumPy数组。

data = np.array(data)

现在,data 变量是一个NumPy数组,这将使它在接下来的步骤中更易于处理。

✨ 查看解决方案并练习

计算线性回归参数

在这一步中,你将学习如何根据数据计算系数 w 和截距 b 的值。

继续在 linear_plot() 函数中,从NumPy数组 data 中提取 xy 值。

x = data[:, 0]
y = data[:, 1]

使用 np.polyfit() 函数计算线性回归参数 wb

w, b = np.polyfit(x, y, 1)

使用 round() 函数将计算出的 wb 值保留两位小数。

w = round(w, 2)
b = round(b, 2)

现在,你已经得到了系数 w 和截距 b 的值,可以在下一步中使用。

✨ 查看解决方案并练习

绘制线性回归线

在这一步中,你将学习如何绘制样本散点图,并根据计算出的参数在图上绘制拟合线。

继续在 linear_plot() 函数中,使用 plt.subplots() 创建一个新的Matplotlib图形和坐标轴。

fig, ax = plt.subplots()

使用 ax.scatter() 函数绘制数据点。

ax.scatter(x, y, label="Data Points")

使用 ax.plot() 函数以及计算出的 wb 值绘制线性回归线。

ax.plot(x, w * x + b, color="red", label=f"Linear Fit: y = {w}x + {b}")

为该图添加标签和图例。

ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()

最后,返回 wb 的值以及 fig 对象。

return w, b, fig

现在,你已经完成了 linear_plot() 函数,该函数对给定数据执行线性回归,并返回系数和Matplotlib绘图对象。

✨ 查看解决方案并练习

运行线性回归绘图

在这最后一步中,你将学习如何执行你的脚本,以便看到线性回归绘图的实际效果。代码的这一部分将使用先前在 linear_plot() 中定义的函数来显示散点图以及线性回归线。

提供的Python代码片段将检查该脚本是否作为主程序运行,如果是,它将调用 linear_plot() 函数并显示生成的绘图。

if __name__ == "__main__":
    ## 调用linear_plot函数来计算回归参数并生成绘图
    w, b, fig = linear_plot()
    ## 显示绘图
    plt.show()

在这里,if __name__ == "__main__": 确保只有当脚本直接运行时才调用 linear_plot() 函数,而不是在作为模块导入时调用。在调用返回斜率 w、截距 b 和图形对象 figlinear_plot() 之后,使用 plt.show() 在窗口中显示绘图。这使你能够直观地检查回归线对数据的拟合情况。

现在你可以按下 linear_regression_plot.py 第一行顶部的“运行单元格”按钮并查看结果。

线性回归绘图结果

通过这种设置,你可以轻松地使用不同的数据集重新运行脚本,或者对回归计算进行调整,以立即获得视觉反馈。

✨ 查看解决方案并练习

总结

恭喜你!你已经完成了这个项目。你可以在LabEx中练习更多实验来提升你的技能。