使用 Matplotlib 进行贝叶斯更新

Beginner

This tutorial is from open-source community. Access the source code

简介

贝叶斯更新是一种统计方法,它使我们能够在有新数据可用时更新假设的概率。在本实验中,我们将使用 Matplotlib 创建一个动画,展示贝叶斯更新的工作原理。具体来说,我们将模拟一个抛硬币实验,并使用贝叶斯更新来估计硬币正面朝上的概率。

虚拟机使用提示

虚拟机启动完成后,点击左上角切换到笔记本标签页,以访问 Jupyter Notebook 进行练习。

有时,你可能需要等待几秒钟让 Jupyter Notebook 完成加载。由于 Jupyter Notebook 的限制,操作验证无法自动化。

如果你在学习过程中遇到问题,随时向 Labby 提问。课程结束后提供反馈,我们会及时为你解决问题。

导入所需库

我们首先导入将在本实验中使用的库。具体来说,我们将使用matplotlib.pyplot进行可视化,使用numpy进行数值计算,使用math进行数学函数运算。

import math

import matplotlib.pyplot as plt
import numpy as np

定义贝塔分布的概率密度函数

贝塔分布是一种连续概率分布,常用于表示概率的分布。在贝叶斯更新中,我们使用贝塔分布作为先验分布,以表示在观察任何数据之前我们对假设概率的信念。然后,随着我们观察到新数据,我们会更新贝塔分布。

为了模拟贝叶斯更新,我们需要定义一个函数来计算贝塔分布的概率密度函数(PDF)。我们可以使用math.gamma函数来计算伽马函数,伽马函数在贝塔分布的 PDF 中会用到。

def beta_pdf(x, a, b):
    return (x**(a-1) * (1-x)**(b-1) * math.gamma(a + b)
            / (math.gamma(a) * math.gamma(b)))

定义 UpdateDist 类

接下来,我们定义一个名为UpdateDist的类,用于在观察到新数据时更新贝塔分布。UpdateDist类接受两个参数:Matplotlib 轴对象和成功的初始概率。

class UpdateDist:
    def __init__(self, ax, prob=0.5):
        self.success = 0
        self.prob = prob
        self.line, = ax.plot([], [], 'k-')
        self.x = np.linspace(0, 1, 200)
        self.ax = ax

        ## 设置绘图参数
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 10)
        self.ax.grid(True)

        ## 这条垂直线代表理论值,绘制的分布应收敛于此。
        self.ax.axvline(prob, linestyle='--', color='black')

__init__方法通过将成功的初始次数设置为零(self.success = 0)以及将成功的初始概率设置为作为参数传递的值(self.prob = prob)来初始化类实例。我们还创建一个线条对象来表示贝塔分布并设置绘图参数。

__call__方法在动画每次更新时被调用。它模拟抛硬币实验并相应地更新贝塔分布。

def __call__(self, i):
        ## 这样绘图就能持续运行,我们只需不断观察该过程的新实现情况
        if i == 0:
            self.success = 0
            self.line.set_data([], [])
            return self.line,

        ## 通过均匀选取并与阈值比较来选择成功情况
        if np.random.rand() < self.prob:
            self.success += 1
        y = beta_pdf(self.x, self.success + 1, (i - self.success) + 1)
        self.line.set_data(self.x, y)
        return self.line,

如果这是动画的第一帧(if i == 0),我们将成功次数重置为零并清除线条对象。否则,我们通过生成一个介于 0 和 1 之间的随机数(np.random.rand())并将其与成功概率(self.prob)进行比较来模拟抛硬币实验。如果随机数小于成功概率,我们将其计为一次成功,并使用beta_pdf函数更新贝塔分布。最后,我们用新数据更新线条对象并返回它。

创建动画

既然我们已经定义了UpdateDist类,那么我们可以使用 Matplotlib 的FuncAnimation类来创建动画。我们创建一个图形对象和一个轴对象,并将轴对象传递给UpdateDist类以创建该类的一个新实例。

fig, ax = plt.subplots()
ud = UpdateDist(ax, prob=0.7)
anim = FuncAnimation(fig, ud, frames=100, interval=100, blit=True)
plt.show()

FuncAnimation类接受几个参数:

  • fig:图形对象
  • udUpdateDist实例
  • frames:要制作动画的帧数
  • interval:帧之间的时间间隔(以毫秒为单位)
  • blit:是否仅更新绘图中已更改的部分

解释结果

动画展示了随着新数据的观察,贝塔分布是如何更新的。黑色虚线代表成功的真实概率(即硬币正面朝上的概率)。随着动画的推进,我们可以看到贝塔分布最初在成功的先验概率(0.7)处有一个峰值,并且随着观察到更多数据,逐渐向成功的真实概率移动。

总结

在本实验中,我们使用 Matplotlib 创建了一个演示贝叶斯更新的动画。我们定义了一个函数来计算贝塔分布的概率密度函数(PDF),以及一个类来在观察到新数据时更新贝塔分布。然后,我们使用 Matplotlib 的FuncAnimation类创建了动画并解释了结果。