Series · ML Math Derivations · Chapter 14

机器学习数学推导(十四):变分推断与变分EM

从一行恒等式出发推导变分推断:ELBO 分解、平均场假设、坐标上升 CAVI、变分 EM,以及让 VAE 得以训练的重参数化技巧。

后验 $p(\mathbf{z}\mid\mathbf{x})$ 算不出来时,你有两条路。采样路线(MCMC)让一条马尔可夫链以后验为平稳分布慢慢走,理论上渐近精确,但收敛慢、难诊断。变分路线(VI)则换个角度:先选一个简单的分布族 $\mathcal{Q}$,再在其中找到与真实后验最接近的那个 $q^\star$。推断变成了优化——同样的工具链既能训练神经网络,也能拟合贝叶斯模型。

这篇笔记从一个恒等式(ELBO)出发,推出平均场假设和坐标上升 CAVI 算法,再把 EM 和变分 EM 看作它的两种特例,最后讲清楚 VAE 背后的那个让 ELBO 变成可微目标的重参数化技巧。

你将看到什么

  • ELBO 恒等式:变分推断为什么本质上是一个优化问题
  • 平均场假设和它带来的闭式坐标上升(CAVI)更新
  • 变分 EM 怎么把经典 EM 的 E 步换成对 $q$ 的优化
  • 重参数化技巧:低方差、可反向传播的 ELBO 梯度估计
  • 变分推断什么时候适用——以及它什么时候会偷偷骗你(mode-seeking、低估方差)

前置知识

  • 第十三篇 的 EM 算法和 ELBO
  • KL 散度、Jensen 不等式
  • 多元微积分、指数族
  • 随机梯度估计的基本概念

1. 后验为什么算不出来

给定观测 $\mathbf{x}$、隐变量 $\mathbf{z}$、参数 $\boldsymbol{\theta}$,贝叶斯推断要算的是后验

$$p(\mathbf{z}\mid\mathbf{x}) \;=\; \frac{p(\mathbf{x},\mathbf{z})}{p(\mathbf{x})},\qquad p(\mathbf{x}) \;=\; \int p(\mathbf{x},\mathbf{z})\,d\mathbf{z}.$$

分子很便宜——就是模型的联合密度。难的是分母,那个证据(evidence)$p(\mathbf{x})$ 要在整个隐空间上积分。除非模型刚好是共轭指数族,否则这个积分基本算不动。

主流的两条路线对比起来很清楚:

MCMC变分推断
思路随机采样确定性优化
渐近行为精确(极限意义下)有偏,受 $\mathcal{Q}$ 限制
计算成本高,链的混合主导低,几步梯度
收敛诊断难(自相关、$\hat R$)容易(ELBO 单调)
大数据扩展性不好,要专门改造天然支持小批量

VI 用偏差换速度。后面讲的全部都是这种权衡的工程化处理。


2. ELBO 恒等式

任取一个隐变量上的分布 $q(\mathbf{z})$,都成立:

$$ \log p(\mathbf{x}) \;=\; \underbrace{\mathbb{E}_q\!\left[\log\frac{p(\mathbf{x},\mathbf{z})}{q(\mathbf{z})}\right]}_{\displaystyle \mathcal{L}(q)\;\text{(ELBO)}} \;+\; \underbrace{\mathrm{KL}\!\big(q(\mathbf{z})\,\big\|\,p(\mathbf{z}\mid\mathbf{x})\big)}_{\displaystyle \geq 0}. $$

推导只要一行:把联合乘除一个 $q$,取对数,分成两块:

$$\log p(\mathbf{x}) = \log\!\int q(\mathbf{z})\,\frac{p(\mathbf{x},\mathbf{z})}{q(\mathbf{z})}\,d\mathbf{z} = \mathbb{E}_q\!\left[\log\frac{p(\mathbf{x},\mathbf{z})}{q(\mathbf{z})}\right] + \mathrm{KL}(q\,\|\,p(\cdot\mid\mathbf{x})).$$

由于 $\log p(\mathbf{x})$ 不依赖 $q$,对 $q$ 最大化 ELBO 等价于最小化 $q$ 到真后验的 KL

ELBO 分解:log 证据 = ELBO + KL gap

图 1. 数据和模型一旦确定,$\log p(\mathbf{x})$ 就是个常数。ELBO 每涨一点,KL gap 就同步缩小一点;推断完美时这个 gap 归零。又因为 KL 非负,ELBO 是边缘似然的合法下界——这正是它能用于模型比较的原因。

值得停下来想一下我们做了什么。我们把一个无法计算的积分(证据),换成了一个在函数空间里求最优的优化问题(找最好的 $q$)。变分推断后续的全部工程化,都是在执行这次替换。


3. 平均场近似

但优化对象是个无限维分布。最简单、也是 Jordan 等人 90 年代推广开的限制方式叫平均场:让 $q$ 在各坐标上分解成乘积,

$$q(\mathbf{z}) \;=\; \prod_{j=1}^{M} q_j(z_j).$$

每个因子 $q_j$ 自己住在某个分布族里,不再做更具体的形式假设。

平均场把联合压成独立边缘

图 2. 中间那张图匹配了真后验(左图)的边缘,但是把所有协方差都丢掉了。右图给出的是 $\mathrm{KL}(q\|p)$ 的全局最优解:方差按 $1-\rho^2$ 的比例缩小,所以这个近似在两个轴上都低估了不确定性。这是平均场 VI 最常见、也最隐蔽的失败模式。

3.1 最优因子的形式

把分解代入 ELBO,盯住其中一个因子 $q_j$,把其他因子当成已知:

$$ \mathcal{L}(q) = \int q_j(z_j) \,\mathbb{E}_{q_{-j}}\!\big[\log p(\mathbf{x},\mathbf{z})\big]\,dz_j \;-\; \int q_j(z_j)\log q_j(z_j)\,dz_j \;+\; \text{常数}. $$

中括号里的期望是 $z_j$ 的函数。配合归一化约束 $\int q_j = 1$ 做变分,最优解是

$$\boxed{\;\log q_j^\star(z_j) \;=\; \mathbb{E}_{q_{-j}}\!\big[\log p(\mathbf{x},\mathbf{z})\big] \;+\; \text{常数}.\;}$$

这就是平均场 VI 的核心公式:第 $j$ 个坐标的最优因子,是联合密度对其他所有因子取几何平均之后再归一化。

3.2 坐标上升变分推断(CAVI)

每个因子的最优解都依赖其他因子,所以循环求解:

初始化 q_1, ..., q_M
重复
    for j = 1, ..., M:
        log q_j(z_j) <- E_{q_{-j}}[log p(x, z)] + 常数
        归一化 q_j
直到 ELBO 变化 < 阈值

每一步都是对一个逐坐标凹的目标做坐标上升,所以 ELBO 单调不降,必收敛到局部最优;至于是不是全局最优,那就不好说了。

CAVI 椭圆逐步收紧、ELBO 单调上升

图 5. 在一个相关二维高斯上跑了 8 步 CAVI。对角的 $q$ 既被拖到原点,又被压扁——方差最终塌缩到 $1/\text{precision}_{ii}$。右图给出的是 ELBO 的单调轨迹,这是任何 VI 实现里最好用的收敛诊断。

3.3 共轭指数族

当模型是共轭指数族(每个条件 $p(z_j\mid \mathbf{z}_{-j},\mathbf{x})$ 都属于指数族)时,CAVI 更新有闭式解:最优 $q_j$ 与那个条件分布同族,更新它无非是在 $q_{-j}$ 下平均自然参数。贝叶斯高斯混合、LDA、贝叶斯线性回归、带 Dirichlet 先验的 HMM 都属于这一类。非共轭模型则要走第 6 节的黑盒路线。


4. 把变分族当成函数逼近器

跳出平均场,从更广义的视角看 VI 也很有启发:选任何参数族 $q_\phi(\mathbf{z})$——可学均值与协方差的高斯、归一化流、共享参数的推断网络——然后对 $\phi$ 最小化 $\mathrm{KL}(q_\phi\,\|\,p(\cdot\mid\mathbf{x}))$ 就行。

双峰目标上的反向 KL vs 正向 KL

图 3. 在双峰目标上,反向 KL 的最优解(紫色)只锁定一个峰——这是 mode-seeking 的几何根源:$\mathrm{KL}(q\|p)$ 里只要 $q>0$ 而 $p\approx 0$ 就会被罚到无穷。正向 KL 的最优解(橙色,EP 算法用)则去匹配前两阶矩、把两个峰都覆盖住,哪怕在两峰之间的低谷处放些质量也在所不惜。VI 选反向 KL 是因为它只需要 $q$ 的样本就能算——但代价是这种不对称会真切地影响下游的不确定性估计。

KL 不对称性:zero-forcing vs mass-covering

图 6. 同一现象在对称双峰上更明显:反向 KL 有多个局部最优,每个解都钉死在一个峰上;正向 KL 给出唯一解,居中、覆盖性强。设计 VI 系统时,先想清楚下游应用能容忍哪种行为,是个能省掉很多调试时间的习惯。

平均场什么时候会咬人

图 2 的方差低估不是 bug,而是反向 KL 加分解族这套组合的几何必然。只要下游任务用到了后验方差(贝叶斯模型平均、置信预测、决策论),平均场 VI 就会低估不确定性。三种常见的逃逸方式:

  1. 结构化变分:保留部分依赖(比如树结构 $q$);
  2. 归一化流:用一串可逆变换让 $q$ 有能力描述相关性;
  3. 摊还推断(amortized):拿个灵活的神经网络当编码器,VAE 就是这个思路。

5. 变分 EM

第十三篇 其实已经在用 ELBO 恒等式:经典 EM 也是它的产物。EM 交替做两件事:

  • E 步:取 $q(\mathbf{z}) = p(\mathbf{z}\mid\mathbf{x};\boldsymbol{\theta}^{(t)})$,也就是精确后验——KL gap 归零,ELBO 顶到 $\log p(\mathbf{x};\boldsymbol{\theta}^{(t)})$;
  • M 步:固定 $q$,对 $\boldsymbol{\theta}$ 最大化 ELBO,等价于最大化 $\mathbb{E}_q[\log p(\mathbf{x},\mathbf{z};\boldsymbol{\theta})]$。

精确后验算不出来时,E 步就崩了。变分 EM 把它换成一次 VI 内循环:

步骤标准 EM变分 EM变分贝叶斯 EM
E 步精确 $p(\mathbf{z}\mid\mathbf{x};\boldsymbol{\theta})$$\mathcal{Q}$ 中最优 $q(\mathbf{z})$$\mathcal{Q}$ 中最优 $q(\mathbf{z},\boldsymbol{\theta})$
M 步最大化 $Q(\boldsymbol{\theta})$最大化 $Q(\boldsymbol{\theta})$折叠进 VI
$\boldsymbol{\theta}$ 视作点估计点估计随机变量

变分 EM 的 E 步不再让 ELBO 顶到对数似然(KL gap 不为零),所以它优化的是 $\log p(\mathbf{x};\boldsymbol{\theta})$ 的下界而非似然本身。但单调性仍然成立——这在工程上是相当宝贵的保证。

变分贝叶斯 EM(VBEM)更彻底:把 $\boldsymbol{\theta}$ 也当成随机变量、配上自己的变分因子,一次性给出 $\mathbf{z}$ 与 $\boldsymbol{\theta}$ 的联合后验。变分贝叶斯 GMM、变分 LDA 解的就是这个版本。


6. 黑盒变分推断与重参数化技巧

模型不共轭时,CAVI 闭式解和变分 E 步都没了。黑盒 VI(BBVI) 用神经网络参数化 $q_\phi$,直接对 ELBO 做随机梯度优化:

$$\nabla_\phi\,\mathcal{L}(\phi) \;=\; \nabla_\phi\,\mathbb{E}_{q_\phi(\mathbf{z})}\!\left[\log p(\mathbf{x},\mathbf{z}) - \log q_\phi(\mathbf{z})\right].$$

外层期望的测度本身依赖 $\phi$,梯度无法白白搬到期望里——这是所有难点的根源。

得分函数估计器(REINFORCE):把对密度的梯度提出来,

$$\nabla_\phi \mathcal{L} \;=\; \mathbb{E}_{q_\phi}\!\left[\big(\log p(\mathbf{x},\mathbf{z}) - \log q_\phi(\mathbf{z})\big)\nabla_\phi \log q_\phi(\mathbf{z})\right].$$

无偏,但方差极大,必须搭配控制变量和较大样本量才能用。

重参数化技巧:只要能把 $\mathbf{z}$ 写成与参数无关的噪声 $\boldsymbol{\epsilon}$ 的确定性函数,

$$\mathbf{z} \;=\; g_\phi(\boldsymbol{\epsilon}),\qquad \boldsymbol{\epsilon}\sim p(\boldsymbol{\epsilon}),$$

期望的测度就跟 $\phi$ 无关了,梯度可以光明正大地塞进去:

$$\nabla_\phi \mathcal{L} \;=\; \mathbb{E}_{p(\boldsymbol{\epsilon})}\!\left[\nabla_\phi\big(\log p(\mathbf{x},g_\phi(\boldsymbol{\epsilon})) - \log q_\phi(g_\phi(\boldsymbol{\epsilon}))\big)\right].$$

对角高斯 $q_\phi(\mathbf{z}) = \mathcal{N}(\boldsymbol{\mu}_\phi,\,\mathrm{diag}(\boldsymbol{\sigma}_\phi^2))$ 的变换是 $g_\phi(\boldsymbol{\epsilon}) = \boldsymbol{\mu}_\phi + \boldsymbol{\sigma}_\phi\odot\boldsymbol{\epsilon}$,$\boldsymbol{\epsilon}\sim\mathcal{N}(\mathbf{0},\mathbf{I})$。整个计算图都可微,autodiff 接管剩下的事,梯度方差比 REINFORCE 低几个数量级。这就是变分自编码器,以及几乎所有现代连续隐变量变分模型背后的引擎。

离散隐变量上重参数化失效(你写不出 $z\in\{0,1\}$ 关于 $\epsilon$ 的光滑函数)。常见替代:带控制变量的 REINFORCE、Gumbel-Softmax/Concrete 松弛,或先连续松弛再 straight-through 估计。

怎么选

VI 与 MCMC 的速度-精度权衡

图 4. 一个典型的速度-精度图。VI 很快达到一个不错的误差,然后停在它的偏差地板上($\mathcal{Q}$ 与真后验之间的距离)。MCMC 经过 burn-in 后按经典 $1/\sqrt{T}$ 速率向零误差收敛。决策矩阵很短:

  • 大数据、在线学习、探索性建模、神经网络内部:用 VI;
  • 后验真的很重要(药效响应、科研结论、要发论文的那种):用 MCMC;
  • 两者结合:用 VI 的均值给 HMC 链做暖启动,既享受快速初始化,又能拿到渐近无偏。

7. 应用:变分 LDA

LDA(Blei、Ng、Jordan 2003)是 VI 大规模应用的经典代表。模型把每篇文档的主题比例 $\theta_d$ 和每个主题的词分布 $\beta_k$ 都设成 Dirichlet。后验算不出,但模型属共轭指数族,平均场 CAVI 给出闭式更新:

$$q(\theta,\beta,z) \;=\; \prod_d q(\theta_d\mid\gamma_d) \prod_k q(\beta_k\mid\lambda_k) \prod_{d,n} q(z_{d,n}\mid\phi_{d,n}),$$

其中 $\theta_d$、$\beta_k$ 用 Dirichlet 因子,每个词的主题指派 $z_{d,n}$ 用 categorical 因子。

变分 LDA:每文档主题分布与每主题词分布

图 7. 变分 LDA 的两个主要输出。:八篇文档的主题比例后验均值 $\mathbb{E}_q[\theta_d]$。每篇文档都是四个主题的软混合——文档 1、5 以"机器学习"为主,文档 3 以"金融"为主,依此类推。:每个主题的词分布后验均值 $\mathbb{E}_q[\beta_k]$,每个主题都把概率质量集中在少量"特征词"上。这种可解释结构正是 LDA 在 2000 年代到 2010 年代成为文本分析主力工具的原因。

随机变分推断 SVI(Hoffman 等 2013)把同样的更新扩展到亿级文档:每步采样一个小批量,用自然梯度更新 $\lambda_k$,把图 4 的速度优势用到了极致。


8. 实现:变分贝叶斯 GMM

下面是一个简洁的共轭贝叶斯 GMM 的 CAVI 实现。更新公式背后的数学在 Bishop PRML §10.2,这里更关心循环结构本身的可读性。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from scipy.special import digamma


class VariationalGMM:
    """平均场变分贝叶斯高斯混合(CAVI 版)。"""

    def __init__(self, K=3, max_iter=100, tol=1e-3):
        self.K = K
        self.max_iter = max_iter
        self.tol = tol

    def fit(self, X):
        N, d = X.shape
        # 弱共轭先验
        alpha0, beta0, nu0 = 1.0, 1.0, float(d)
        m0, W0 = X.mean(0), np.eye(d)

        # 变分参数:先验 + 均匀分配的初始化
        self.alpha = np.full(self.K, alpha0 + N / self.K)
        self.beta = np.full(self.K, beta0 + N / self.K)
        self.nu = np.full(self.K, nu0 + N / self.K)
        self.m = np.array([m0 + 0.1 * np.random.randn(d) for _ in range(self.K)])
        self.W = np.array([W0.copy() for _ in range(self.K)])

        r = np.random.dirichlet([1] * self.K, N)
        for _ in range(self.max_iter):
            r_old = r.copy()

            # ---- E 步:变分责任度  q(z_n) ----
            r = self._update_r(X, N, d)

            # ---- M 步:Dirichlet 与 Normal-Wishart 因子的闭式更新 ----
            N_k = r.sum(0)
            x_bar = (r.T @ X) / N_k[:, None]
            self.alpha = alpha0 + N_k
            self.beta = beta0 + N_k
            self.nu = nu0 + N_k
            self.m = (beta0 * m0 + N_k[:, None] * x_bar) / self.beta[:, None]
            for k in range(self.K):
                diff = X - x_bar[k]
                S = (r[:, k, None] * diff).T @ diff / N_k[k]
                dm = x_bar[k] - m0
                self.W[k] = np.linalg.inv(
                    np.linalg.inv(W0)
                    + N_k[k] * S
                    + beta0 * N_k[k] / (beta0 + N_k[k]) * np.outer(dm, dm)
                )

            if np.max(np.abs(r - r_old)) < self.tol:
                break
        return self

    def _update_r(self, X, N, d):
        """rho_{nk} = exp( E_q[log pi_k] + 0.5 E_q[log|Lambda_k|]
                           - 0.5 E_q[(x-mu_k)^T Lambda_k (x-mu_k)] )"""
        r = np.zeros((N, self.K))
        for k in range(self.K):
            E_lp = digamma(self.alpha[k]) - digamma(self.alpha.sum())
            E_ld = sum(digamma((self.nu[k] + 1 - i) / 2) for i in range(1, d + 1))
            E_ld += d * np.log(2) + np.log(max(np.linalg.det(self.W[k]), 1e-10))
            for n in range(N):
                diff = X[n] - self.m[k]
                r[n, k] = (E_lp
                           + 0.5 * E_ld
                           - 0.5 * (self.nu[k] * diff @ self.W[k] @ diff
                                    + d / self.beta[k]))
        r = np.exp(r - r.max(1, keepdims=True))
        return r / r.sum(1, keepdims=True)

    def predict(self, X):
        return np.argmax(self._update_r(X, len(X), X.shape[1]), axis=1)

算法骨架与第十三篇的 EM-GMM 相似——交替更新责任度和成分统计量——但每个量都是 $q$ 下的后验矩,而非点估计。一个附带的好处是自动剪枝:空成分的 $\alpha_k$ 会向先验 $\alpha_0$ 收缩,没用的主题会自动淡出,不会硬去拟合噪声。


9. 常见问题

Q1:为什么用反向 KL 而不是正向 KL? 反向 $\mathrm{KL}(q\|p)$ 只要在 $q$ 下取期望,这正是我们能控制的;正向 $\mathrm{KL}(p\|q)$ 要在算不出来的 $p$ 下取期望。代价就是 mode-seeking 的行为——见图 6。

Q2:我的 ELBO 在下降,是不是哪里坏了? 是。CAVI 的 ELBO 按构造单调不降。掉下来八成是 bug,常见原因:(i) 责任度归一化搞错;(ii) 熵项的符号写反;(iii) 在用某个变分参数时,它还没被这一轮更新过。

Q3:ELBO 这个下界紧吗? $\log p(\mathbf{x}) - \mathcal{L}(q^\star) = \mathrm{KL}(q^\star\|p(\cdot\mid\mathbf{x})) \geq 0$。这个 gap 正好是变分族本身的建模误差。所以基于 ELBO 做模型比较,只在不同模型 gap 量级可比时才公允。

Q4:平均场什么时候会"灾难性"失效? 后验有强相关(图 2)和多峰(图 6)。症状:方差被严重低估、预测过度自信、对初始化高度敏感。补救:结构化 VI、归一化流,或全协方差摊还推断。

Q5:重参数化 vs REINFORCE,怎么选? 连续可微的 $q$(高斯、Logistic、用 Gumbel-Softmax 包装的混合)用重参数化;离散隐变量用 REINFORCE。两者都适用时,重参数化的梯度方差通常低 100~1000 倍。

Q6:变分 EM 和完整变分贝叶斯(VBEM),用哪个? 需要 $\boldsymbol{\theta}$ 的不确定性(小数据、模型选择、决策论)→ 用 VBEM;只需要好的点估计(大数据、追求预测精度)→ 变分 EM 更快更省事。


10. 习题

E1. 从头证明 $\mathcal{L}(q) \leq \log p(\mathbf{x})$。 提示. $\log p(\mathbf{x}) = \mathcal{L}(q) + \mathrm{KL}(q\|p(\cdot\mid\mathbf{x}))$,KL 由 Jensen 不等式非负即得。

E2. 对二变量平均场 $q(z_1,z_2) = q_1(z_1)q_2(z_2)$,写出最优 $q_1^\star$。 答案. $\log q_1^\star(z_1) = \mathbb{E}_{q_2}[\log p(\mathbf{x},z_1,z_2)] + \text{常数}$。

E3. 推导一个零均值二维高斯(精度矩阵 $\Lambda$)的 CAVI 闭式更新;证明收敛后 $\mathrm{Var}_q[z_j] = 1/\Lambda_{jj}$——是精度矩阵的对角元的倒数,而不是协方差矩阵的对角元。 提示. 复现图 5 的轨迹。

E4. 证明变分 EM 的 ELBO 在迭代间单调不降,即使 E 步不再精确。 提示. E 步提升 ELBO(更新 $q$),M 步提升 ELBO(更新 $\boldsymbol{\theta}$),交替结构保持单调性。

E5. 为什么直接用 $\nabla_\phi \mathbb{E}_{q_\phi}[f(\mathbf{z})] \approx \frac{1}{S}\sum_s \nabla_\phi f(\mathbf{z}^{(s)})$($\mathbf{z}^{(s)}\sim q_\phi$)会算错? 答案. 样本 $\mathbf{z}^{(s)}$ 本身依赖 $\phi$,而朴素梯度忽略了这种依赖。REINFORCE 和重参数化是两种正确处理这一依赖的方式。


参考文献

  • Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine Learning, 37(2), 183–233.
  • Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for statisticians. JASA, 112(518), 859–877.
  • Hoffman, M. D., Blei, D. M., Wang, C., & Paisley, J. (2013). Stochastic variational inference. JMLR, 14, 1303–1347.
  • Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. ICLR.
  • Ranganath, R., Gerrish, S., & Blei, D. M. (2014). Black box variational inference. AISTATS.
  • Bishop, C. M. (2006). Pattern Recognition and Machine Learning, 第 10 章。Springer.

系列导航

Liked this piece?

Follow on GitHub for the next one — usually one a week.

GitHub