介绍
在这个项目中,你将学习如何对一组数据点进行线性回归,并使用 Matplotlib 可视化结果。线性回归是一种基本的机器学习技术,用于对因变量(y)和一个或多个自变量(x)之间的关系进行建模。
🎯 任务
在这个项目中,你将学习:
- 如何将给定数据转换为 Numpy 数组以便于操作
- 如何计算线性回归模型的系数,包括斜率(w)和截距(b)
- 如何在散点图上绘制数据点,并在同一图上绘制线性回归直线
🏆 成果
完成这个项目后,你将能够:
- 为线性回归分析准备数据
- 使用 Numpy 函数计算线性回归参数
- 使用 Matplotlib 创建散点图并叠加线性回归直线
- 更好地理解线性回归及其在数据分析和可视化中的实际应用
将数据转换为 NumPy 数组
在这一步中,你将学习如何将给定数据转换为 NumPy 数组以便于操作。
打开 linear_regression_plot.py 文件。
在 linear_plot() 函数开头找到 data 变量。
使用 np.array() 函数将 data 列表转换为 NumPy 数组。
data = np.array(data)
现在,data 变量是一个 NumPy 数组,这将使它在接下来的步骤中更易于处理。
计算线性回归参数
在这一步中,你将学习如何根据数据计算系数 w 和截距 b 的值。
继续在 linear_plot() 函数中,从 NumPy 数组 data 中提取 x 和 y 值。
x = data[:, 0]
y = data[:, 1]
使用 np.polyfit() 函数计算线性回归参数 w 和 b。
w, b = np.polyfit(x, y, 1)
使用 round() 函数将计算出的 w 和 b 值保留两位小数。
w = round(w, 2)
b = round(b, 2)
现在,你已经得到了系数 w 和截距 b 的值,可以在下一步中使用。
绘制线性回归线
在这一步中,你将学习如何绘制样本散点图,并根据计算出的参数在图上绘制拟合线。
继续在 linear_plot() 函数中,使用 plt.subplots() 创建一个新的 Matplotlib 图形和坐标轴。
fig, ax = plt.subplots()
使用 ax.scatter() 函数绘制数据点。
ax.scatter(x, y, label="Data Points")
使用 ax.plot() 函数以及计算出的 w 和 b 值绘制线性回归线。
ax.plot(x, w * x + b, color="red", label=f"Linear Fit: y = {w}x + {b}")
为该图添加标签和图例。
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
最后,返回 w、b 的值以及 fig 对象。
return w, b, fig
现在,你已经完成了 linear_plot() 函数,该函数对给定数据执行线性回归,并返回系数和 Matplotlib 绘图对象。
运行线性回归绘图
在这最后一步中,你将学习如何执行你的脚本,以便看到线性回归绘图的实际效果。代码的这一部分将使用先前在 linear_plot() 中定义的函数来显示散点图以及线性回归线。
提供的 Python 代码片段将检查该脚本是否作为主程序运行,如果是,它将调用 linear_plot() 函数并显示生成的绘图。
if __name__ == "__main__":
## 调用 linear_plot 函数来计算回归参数并生成绘图
w, b, fig = linear_plot()
## 显示绘图
plt.show()
在这里,if __name__ == "__main__": 确保只有当脚本直接运行时才调用 linear_plot() 函数,而不是在作为模块导入时调用。在调用返回斜率 w、截距 b 和图形对象 fig 的 linear_plot() 之后,使用 plt.show() 在窗口中显示绘图。这使你能够直观地检查回归线对数据的拟合情况。
现在你可以按下 linear_regression_plot.py 第一行顶部的“运行单元格”按钮并查看结果。

通过这种设置,你可以轻松地使用不同的数据集重新运行脚本,或者对回归计算进行调整,以立即获得视觉反馈。
总结
恭喜你!你已经完成了这个项目。你可以在 LabEx 中练习更多实验来提升你的技能。



