强化学习(十):离线强化学习

离线强化学习从固定数据集学习策略,无需任何在线交互。本文系统讲解分布偏移、外推误差,以及 CQL、BCQ、IQL、Decision Transformer 四类主流方法,配有完整的 CQL PyTorch 实现与 D4RL 基准对比。

到目前为止,我们学过的每一个 RL 算法都依赖同一个循环:行动—观察—更新。这个循环让 RL 能够工作,但也让它在很多场景里根本无法落地。自动驾驶不能靠真实撞车来学习路口的处理;医疗决策模型不能在病人身上跑随机策略;产线上的机器人没有几千次失败抓取可以浪费。

但是这些场景里都有一样东西不缺——日志:上百万小时的人类驾驶数据、几十年的脱敏病历、TB 级的行为克隆数据。离线强化学习(Offline RL,又称 Batch RL)的问题是:在完全不与环境交互的前提下,能不能从一个固定数据集里训练出一个强策略?

答案是"可以,但前提是足够小心"。这个"小心"的来源,就是本文的核心主题——产生数据的行为策略与我们想要学的目标策略之间的分布偏移

你将学到什么

  • 为什么把任何一个标准 off-policy 算法直接搬到离线设定下都会失败:外推误差、价值高估、和"死亡螺旋"。
  • CQL(Conservative Q-Learning):用一个悲观正则项把 Q 值压成真实价值的下界。
  • BCQ(Batch-Constrained Q-Learning):用 VAE 限制策略只能从行为分布里挑动作。
  • IQL(Implicit Q-Learning):用期望分位数回归彻底绕开 OOD 动作。
  • Decision Transformer:把 RL 重写成"按目标回报做条件的序列建模"。
  • 在 D4RL 基准上,这几类方法各自适合什么样的数据
  • 一份可直接运行的 CQL PyTorch 实现。

前置知识

  • Q-learning、目标网络、Actor-Critic(第 2 部分第 6 部分 )。
  • 熟悉 Bellman 最优算子和重要性采样。
  • PyTorch、Gym/Gymnasium 接口。

1. 离线 RL 真正难在哪

在线 vs 离线 RL:缺失的反馈回路

在线 RL 里,策略和数据分布是耦合在一起的。一旦策略对某个动作产生高估,下一轮采样就会暴露这个错误,Bellman 更新会立刻把它压回去。离线 RL 里数据是冻结的,没有第二次机会:模型对未见动作犯下的任何错误都会永远留在 Q 函数里,而 argmax 算子会非常乐意去利用它。

1.1 分布偏移

设产生数据集 $\mathcal{D}=\{(s_i,a_i,r_i,s_i')\}_{i=1}^N$ 的行为策略是 $\pi_\beta$,我们要学的策略是 $\pi_\theta$。两者的状态-动作访问分布通常不一样:

$$d_{\pi_\theta}(s,a)\neq d_{\pi_\beta}(s,a).$$

只要 $\pi_\theta$ 想要选择某个 $\pi_\beta(a\mid s)\approx 0$ 的动作,问题就来了:Q 网络从来没有为这个动作算过 TD 目标,它给出的任何值都纯粹是从一个对其他区域拟合的神经网络外推出来的——而神经网络的外推几乎总是离谱的。

1.2 外推误差与死亡螺旋

分布偏移导致 OOD 动作上的 Q 值被严重高估

Q-learning 的更新公式是

$$Q(s,a)\leftarrow r + \gamma\,\max_{a'}Q(s',a').$$

max 算子专门挑选最乐观的那一个外推值。哪怕 Q 函数只对一个 OOD 动作高估了 $\epsilon$,这个偏差就会成为前一时刻的 bootstrap 目标,再成为更前一时刻的目标……一路传播回去。在标准基准上,这种发散通常几千步就能观察到Fujimoto 等, 2019 )。上图右侧那张图就是要记住的画面:绿色是真实 Q 函数,红色是网络相信的 Q 函数,策略会沿着 OOD 那个虚假的山峰一路走到悬崖外。

后面要讲的三类算法都在攻击这个问题,区别只在于它们究竟约束了什么

类别一句话总结约束对象
策略约束(BCQ、BEAR、TD3+BC)“只在数据见过的地方动作”$\pi_\theta$
价值悲观(CQL、MOPO)“对没见过的 Q 值不要相信”$Q$
样本内学习(IQL、AWAC)“压根不去查 OOD 动作”损失函数本身
序列建模(Decision Transformer 等)“干脆绕过 Bellman”问题的整个范式

2. CQL:保守 Q-Learning

CQL(Kumar 等, 2020 )是目前业界最常用的离线 RL 基线。它不改网络结构、不改 actor、不改数据 pipeline,只在损失里加一项。

2.1 一句话原理

把策略可能选择的所有动作的 Q 值往下压;再把数据里出现过的动作的 Q 值拉回来。 净效果是:Q 函数恰好在"它没有证据的地方"显得悲观。

2.2 目标函数

在普通 TD 损失 $\mathcal{L}_{\mathrm{TD}}$ 之外,CQL 多加一项:

$$\mathcal{L}_{\mathrm{CQL}} \;=\; \alpha\,\Big[\,\underbrace{\log\sum_{a}\exp Q(s,a)}_{\text{压低所有动作}} \;-\; \underbrace{\mathbb{E}_{a\sim\mathcal{D}}\big[Q(s,a)\big]}_{\text{把数据动作拉回来}}\Big] \;+\; \mathcal{L}_{\mathrm{TD}}.$$

logsumexp 是一个对动作的"软最大"。最小化它会把所有动作的 Q 值往下拽。再减去数据分布下的期望,相当于把"在数据里出现过"的动作浮回原位。OOD 动作只感受到向下的力。

CQL 的核心定理(原文 Thm 3.2 )说,当 $\alpha$ 足够大时,

$$\hat{Q}^{\pi}(s,a)\;\leq\;Q^{\pi}(s,a)\quad\forall (s,a).$$

一个真实价值的下界正是我们想要的:任何最大化下界的策略都不可能选出灾难性动作,因为对它来说那种动作没有任何"看上去的好处"。

CQL:正则项把 OOD 区域的 Q 值往下压,让 argmax 留在数据支撑集里

左图是普通 SAC critic 在离线数据上的表现:argmax 直接逃到 OOD 那个虚假高峰上去了。右图是 CQL 修正后的样子:橙色阴影是悲观惩罚的部分,新的 argmax 安静地落在数据支撑集里面。

2.3 工程上的几个细节

  • 连续动作下 logsumexp 没法精确算,实现里通常用 n_random 个均匀样本加 n_actor 个当前策略的样本来近似,10~20 个就够。
  • 原论文里的 “CQL($\mathcal{H}$)” 变体加了一个拉格朗日项来自动调 $\alpha$,让差值 $\mathbb{E}_{a\sim\pi}Q-\mathbb{E}_{a\sim\mathcal{D}}Q$ 维持在某个目标附近——d3rlpyJaxRL 默认就是这个版本。
  • 经验上 CQL 在 D4RL 的 medium-replay / medium 上很稳,但在 medium-expert 上略偏保守,那里 IQL 或 DT 通常会赢。

2.4 一份完整的 CQL 实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(nn.Module):
    """孪生 Q 网络(Clipped Double Q),用来抑制高估。"""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.q1 = self._mlp(state_dim + action_dim, hidden)
        self.q2 = self._mlp(state_dim + action_dim, hidden)

    @staticmethod
    def _mlp(in_dim, hidden):
        return nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, state, action):
        sa = torch.cat([state, action], dim=1)
        return self.q1(sa), self.q2(sa)


class GaussianPolicy(nn.Module):
    """SAC 风格的 tanh-squashed 高斯策略。"""

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mean = nn.Linear(hidden, action_dim)
        self.log_std = nn.Linear(hidden, action_dim)

    def sample(self, state):
        x = self.trunk(state)
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        # tanh 变换的雅可比项
        log_prob = (normal.log_prob(z)
                    - torch.log(1 - action.pow(2) + 1e-6)).sum(1, keepdim=True)
        return action, log_prob


class CQLAgent:
    """在 SAC 之上加一个 CQL 正则项。cql_weight 就是公式里的 alpha。"""

    def __init__(self, state_dim, action_dim, lr=3e-4,
                 gamma=0.99, tau=0.005, alpha=0.2,
                 cql_weight=1.0, n_random=10):
        self.gamma, self.tau = gamma, tau
        self.cql_weight, self.n_random = cql_weight, n_random
        self.action_dim = action_dim

        self.q_net = QNetwork(state_dim, action_dim)
        self.target_q = QNetwork(state_dim, action_dim)
        self.target_q.load_state_dict(self.q_net.state_dict())
        self.policy = GaussianPolicy(state_dim, action_dim)

        self.q_opt = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.pi_opt = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.alpha = alpha

    # ---------- 与普通 SAC 唯一不同的地方就在这里 ------------
    def _cql_penalty(self, states, q1_data, q2_data):
        """对动作做 logsumexp,再减去数据上的 Q 平均。"""
        b = states.size(0)
        # (1) [-1, 1] 上的均匀随机动作
        rand_a = torch.empty(b * self.n_random, self.action_dim).uniform_(-1, 1)
        # (2) 当前策略在 s 处采样的动作
        rep_s = states.repeat_interleave(self.n_random, 0)
        pi_a, pi_lp = self.policy.sample(rep_s)

        def _q(s, a):
            return self.q_net(s, a)

        q1_rand, q2_rand = _q(rep_s, rand_a)
        q1_pi,   q2_pi   = _q(rep_s, pi_a)

        # 重要性加权 logsumexp(Kumar 等, Eq. 4)
        log_pi_uniform = -float(self.action_dim) * torch.log(torch.tensor(2.0))
        cat1 = torch.cat([q1_rand - log_pi_uniform,
                          q1_pi   - pi_lp.detach()], 0).view(b, -1)
        cat2 = torch.cat([q2_rand - log_pi_uniform,
                          q2_pi   - pi_lp.detach()], 0).view(b, -1)

        lse1 = torch.logsumexp(cat1, 1, keepdim=True)
        lse2 = torch.logsumexp(cat2, 1, keepdim=True)
        return ((lse1 - q1_data) + (lse2 - q2_data)).mean()

    def update(self, states, actions, rewards, next_states, dones):
        # 1. Bellman 目标(clipped double Q + 熵奖励)
        with torch.no_grad():
            next_a, next_lp = self.policy.sample(next_states)
            tq1, tq2 = self.target_q(next_states, next_a)
            target_q = torch.min(tq1, tq2) - self.alpha * next_lp
            target = rewards + self.gamma * (1 - dones) * target_q

        q1, q2 = self.q_net(states, actions)
        td_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)
        cql = self._cql_penalty(states, q1, q2)
        q_loss = td_loss + self.cql_weight * cql

        self.q_opt.zero_grad()
        q_loss.backward()
        self.q_opt.step()

        # 2. Actor 更新(标准 SAC)
        a, lp = self.policy.sample(states)
        q1_a, q2_a = self.q_net(states, a)
        pi_loss = (self.alpha * lp - torch.min(q1_a, q2_a)).mean()
        self.pi_opt.zero_grad()
        pi_loss.backward()
        self.pi_opt.step()

        # 3. 目标 Q 的 Polyak 平均
        with torch.no_grad():
            for p, tp in zip(self.q_net.parameters(),
                             self.target_q.parameters()):
                tp.data.mul_(1 - self.tau).add_(self.tau * p.data)

        return {"td": td_loss.item(),
                "cql": cql.item(),
                "pi":  pi_loss.item()}

接一个离线 replay buffer(用 minarid4rl-pybullet 加载 D4RL 数据),训练大约 1M 梯度步,应该就能复现下面基准图里的数字。


3. BCQ:用生成模型限制候选动作

CQL 约束的是价值;BCQ(Fujimoto 等, 2019 )直接约束策略本身。它的思路非常直接:永远不要在行为策略不会产生的动作上查询 Q 函数

BCQ:VAE 提候选,扰动网络做小幅修正,最后用 argmax 选一个

整套架构有三块:

  1. 条件 VAE $G_\omega(s)$:在 $\mathcal{D}$ 上用标准 ELBO 训练,建模 $\pi_\beta(a\mid s)$。从它采样得到的动作,天然就在行为流形上。
  2. 扰动网络 $\xi_\phi(s,a)\in[-\Phi,\Phi]^{\dim(a)}$:对每个 VAE 候选加一个小幅、有界的修正。$\Phi$ 取小(一般 0.05)保证候选不会偏离数据;$\Phi=0$ 时 BCQ 就退化成加权模仿。
  3. 孪生 Q 网络:对 $N$ 个修正后的候选打分,argmax 决定真正执行的动作。

精妙之处在于,扰动网络的目标是最大化 Q——所以 BCQ 既享受到了策略改进的红利(不止是模仿),又避免了 OOD 外推(候选动作动不出数据流形)。代价是 pipeline 比较复杂——四个网络,而且 VAE 必须训得好,否则候选本身就有偏。


4. IQL:从根上避开 bootstrap 问题

IQL(Kostrikov 等, 2022 )是这几个方法里最优雅的。它的洞见非常干脆:离线 RL 的所有毛病都来自 Bellman 目标里的 $\max_{a'}Q(s',a')$,因为 OOD 动作就是从这里钻进来的。那……干脆别用这个 max。

IQL 单独学一个状态价值函数 $V(s)$,只在数据里出现过的状态-动作对上做期望分位数回归

$$\mathcal{L}_V \;=\; \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[L_2^{\tau}\big(Q(s,a)-V(s)\big)\big],\qquad L_2^{\tau}(u)=\big|\tau-\mathbb{1}(u<0)\big|\,u^{2}.$$

$\tau=0.5$ 时这就是普通 MSE,$V$ 学到的是均值。$\tau\to 1$ 时它对"低估"的惩罚远大于对"高估"的惩罚,于是 $V$ 收敛到 $Q(s,a)$ 在数据里出现过的动作上的上分位数

IQL 的非对称期望分位数损失:调高 tau 会让 V 逼近 max_a Q,但全程不离开数据

当 $\tau\approx 0.7$~$0.9$ 时,$V(s)$ 已经是 $\max_a Q(s,a)$ 的极好代理——而它的计算只用到了数据里出现过的动作。Bellman 目标变成

$$y \;=\; r + \gamma\,V(s'),$$

没有 max、没有策略采样、没有外推。Actor 通过优势加权回归(Advantage-Weighted Regression)恢复出来:

$$\mathcal{L}_\pi \;=\; -\mathbb{E}_{(s,a)\sim\mathcal{D}}\big[\exp\big(\beta\,(Q(s,a)-V(s))\big)\,\log\pi_\theta(a\mid s)\big].$$

IQL 是当下离线 RL 算法里"动得最少"的一个,在 D4RL 的 AntMaze、Adroit 等"好数据稀疏 + max 最危险"的环境里持续霸榜。


5. Decision Transformer:把 RL 当成序列建模

如果都已经放弃 Bellman 方程了,何必停在 IQL 这一步?Decision TransformerChen 等, 2021 )连价值函数都不要了,直接把 RL 改写成"下一个 token 预测"的问题。

Decision Transformer:以 (return-to-go, state, action) 三元组为 token,喂给一个因果 Transformer

把一条轨迹排成 $(\hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \ldots)$ 的序列,其中 $\hat{R}_t = \sum_{t'\geq t} r_{t'}$ 是从 $t$ 时刻起的剩余回报(return-to-go)。一个标准 GPT 风格的因果 Transformer 用交叉熵或 MSE 学着在已有 token 的条件下预测 $a_t$。

测试时把你期望的剩余回报放进 $\hat{R}_1$,然后让模型自回归往下生成动作。$\hat{R}_1$ 就像一个旋钮:要更激进的策略,就给一个更大的目标回报。

这条路线的好处:

  • 没有 bootstrap,自然就没有外推误差。
  • 长上下文让模型"免费"获得部分可观测下的历史条件。
  • Transformer 的所有工程红利都直接拿来用(LayerNorm、RoPE、FlashAttention、混合精度)。
  • 一套范式从 D4RL 一路 scale 到 Atari 再到大规模多任务(如 Gato)。

代价:

  • 永远超不过数据里出现过的最佳回报:如果 $\mathcal{D}$ 里没有任何一条轨迹得到过 90 分,你向模型要 90 分,得到的就是垃圾。
  • 没有显式的因果信用分配——它学的是"在某个回报条件下模仿轨迹",而不是"哪个动作导致了结果"。
  • 在小基准上要追平 CQL/IQL 的分数,通常要更大的模型和更多的数据。

6. 真实数据:D4RL 基准对比

D4RL MuJoCo 运动控制:BC、BCQ、CQL、IQL、DT 在三种数据质量上的表现

上图整理自 CQL(Kumar 等, 2020 )、IQL(Kostrikov 等, 2022 )和 DT(Chen 等, 2021 )原论文里在 D4RL MuJoCo 运动控制套件(hopperhalfcheetahwalker2d 取平均)上的代表性归一化分数。100 表示专家水平,0 表示随机策略。

三个值得记住的结论:

  1. medium-replay 是最考验保守性的数据。 这种数据来自训练早期的 SAC 检查点——动作很差,但有大量"从坏状态恢复"的转移。CQL 和 IQL 在这种数据上能把 BC 的分数翻一倍;DT 因为没有价值 bootstrap,没办法"拼接"不同轨迹里的好片段,所以落后。
  2. medium-expert 是 DT 的主场。 数据里已经有近似最优的轨迹时,序列建模的简单性胜出。
  3. IQL 是最稳的。 它几乎不会在某个单一数据集上拿第一,但也几乎不会跌出前三,且训练过程在四种里最稳定。

如果不知道选哪个:默认选 IQL(最稳)或 CQL(最简单);当数据已经接近专家水平时考虑 DT


7. 常见问题

Q: 离线 RL 在什么情况下会彻底失败? 有三种典型失败模式:(i) 数据覆盖太窄——只有专家轨迹、没有恢复样本,策略学不会处理自己的失误;(ii) 数据质量极低——长程任务上的随机策略数据;(iii) 测试环境漂移——评估 MDP 的转移动力学和产生 $\mathcal{D}$ 的环境不一样。

Q: CQL 的 $\alpha$ 怎么调? 经验范围:专家数据 $\alpha=0.5$~$1.0$,中等数据 $1.0$~$5.0$,随机/replay 数据 $5.0$~$10.0$。更省心的做法是用原论文的拉格朗日变体,让 $\alpha$ 自动追踪一个目标 gap(通常设为 5~10)。

Q: 离线 RL 和模仿学习到底有什么本质差别? 行为克隆(BC)忽略奖励,只能逼近示范者,永远超不过它。离线 RL 利用奖励信号,可以拼接(stitching)不同轨迹里的好片段,得到一个比 $\mathcal{D}$ 里任何单条轨迹都更好的策略。经典的拼接例子:A 轨迹笨拙地走到状态 $s$ 但之后做得很好,B 轨迹高效到达 $s$ 但之后做得很差——离线 RL 可以学出"B 的开头 + A 的结尾",BC 永远学不出来。

Q: 是不是该直接上基于模型的离线方法? 对高维连续控制,MOPO/MOReL/COMBO 也很有竞争力,但需要额外学一个动力学模型并附带它自己的悲观项。它们在数据少的时候有优势(模型提供归纳偏置),但工程成本明显更高。在 2025 年,模型无关的 CQL/IQL 仍是默认起点。

Q: 离线到在线(offline-to-online)微调怎么做? 最近两年 IQL 在这个方向上拉开差距。AWACCal-QLRLPD 等方法用离线策略初始化,然后用少量在线数据继续训练。CQL 初始化往往太保守,反而拖累微调;IQL 或 AWAC 初始化通常效果更好。


参考文献


系列导航

Liked this piece?

Follow on GitHub for the next one — usually one a week.

GitHub