SR-GNN —— Session-based Recommendation with Graph Neural Networks
SR-GNN 把一段点击会话拆成有向加权图,再用门控 GNN 做下一跳预测。本文系统讲清会话图构建、GGNN 更新、局部+全局池化、训练细节、基准对比,以及决定要不要在生产里用它的几类失败模式。
用户依次点击 A、B、C、B、D。把它喂给一个序列模型,得到的是五个 token 折叠出的隐状态;交给 SR-GNN,得到的是一张有向图——边 B -> C 即便用户回到 B 也仍然存在,节点 B 只出现一次(它的入边和出邻居都会贡献到它的表示上),整段点击的拓扑都被原样保留在邻接矩阵里。这就是 SR-GNN(Wu 等, AAAI 2019)
在多个会话推荐基准上稳稳压过 GRU4Rec、NARM 等纯序列模型的根本原因。
这篇笔记把 SR-GNN 从头到尾拆开讲:会话图怎么建、门控 GNN(GGNN)在图上做了什么、会话向量为什么要做"局部+全局"两路融合、打分和训练有哪些坑,以及它在哪些情况下会崩——目的不是说服你用它,而是让你能准确判断什么时候不要用它。
你会学到什么
- 怎么从点击流构造一张有向加权会话图,以及入/出邻接矩阵怎么喂给 GGNN
- 门控 GNN 更新可以拆成"先聚合消息、再走 GRU 单元"两步
- 会话池化为什么要做局部+全局,以及注意力为什么要锚定在最后一个点击上
- 训练目标、几个真正合理的超参数,以及在这种小图上 BPTT 究竟意味着什么
- 在带回访、长度中等的会话上,图视角为什么稳赢序列模型
- 几类典型失败模式(短会话、流行度坍塌、冷启动物品、超大物品池)和对应的修法
- 几个值得知道的常见变体(attention 化的 GGNN、时间间隔边、多任务头),以及它们在什么场景下真正有用
前置准备
- 熟悉消息传递与 GNN 基础术语(邻接矩阵、传播步数)
- 了解 GRU/LSTM 的门控结构
- 会用推荐评测指标 Recall@K、MRR@K,知道 sampled softmax 是干什么的
1. 背景:会话推荐到底特殊在哪
会话推荐的限制条件比常见 CF 苛刻得多:没有稳定的长期用户画像可用,能看到的就是当前这一段短点击序列 $s = [v_{s,1}, v_{s,2}, \dots, v_{s,n}]$,物品池为 $V = \{v_1, \dots, v_{|V|}\}$,目标是预测下一跳 $v_{s,n+1}$。模型最终输出在 $|V|$ 上的打分向量 $\hat z$,取 top-$K$ 推荐。
这种设定让传统 CF 和纯 RNN 都不太顺手,原因有两点:
- 上下文极短:典型会话只有 2–10 个点击,模型必须从极少的信号里反推用户当下的"意图"。
- 物品会重复出现,意图也不单调:用户会反复比较,绕回来再点同一个商品。同一物品在一段会话里出现多次是常态,“靠后的点击"也未必比"靠前的点击"代表更稳定的偏好。
纯序列模型把这些点击压成一个隐状态,就不可避免地丢掉了"重复访问的物品之间还有结构信息"这一事实。SR-GNN 的贡献,就是把这种结构原原本本以图的形式保留下来,剩下交给消息传递。
2. 会话图的构建
对每段会话 $s$,SR-GNN 构造有向图 $G_s = (V_s, E_s)$,节点集 $V_s$ 是这段会话里去重后的物品(不论某物品被点击几次,作为节点只出现一次),边集 $E_s$ 包含每一次观察到的转移 $u \to v$。边权按起点的出度归一化:
$$ w_{u \to v} \;=\; \frac{\#(u \to v)}{\mathrm{outdeg}(u)} \, . $$由此自然得到两张邻接矩阵:入度邻接 $A^{(\text{in})}$(第 $i$ 行告诉节点 $i$ 谁会喂信息给它)和出度邻接 $A^{(\text{out})}$(第 $i$ 行告诉节点 $i$ 它会喂信息给谁)。两张矩阵横着拼成 $|V_s| \times 2|V_s|$ 的 $A_s$,作为 GGNN 真正消费的输入。

具体看点击流 A, B, C, B, D 的边权计算:
| 边 | 出现次数 | 起点出度 | 边权 |
|---|---|---|---|
| $A \to B$ | 1 | 1 | 1.00 |
| $B \to C$ | 1 | 2 | 0.50 |
| $C \to B$ | 1 | 1 | 1.00 |
| $B \to D$ | 1 | 2 | 0.50 |
如果直接用 GRU 跑这串序列,第二次访问 B 时新的隐状态会把”B -> C“这条转移悄悄覆盖掉;图的视角则一字不漏地保留下来。
3. 门控 GNN 的传播
每个物品 $v$ 都有一个 $d$ 维向量 $h_v$,从全局物品表 $V \in \mathbb{R}^{|V|\times d}$ 里查出来。SR-GNN 在会话图上跑 $T$ 轮消息传递,用的是 门控图神经网络 GGNN(Li 等, 2016)——本质上就是一个以"聚合后的消息"为输入的 GRU 单元。
第 $t$ 步,节点 $i$ 收到的聚合消息是
$$ a_t^{(i)} \;=\; A_{s,i:}\, \big[h_1^{(t-1)}, \dots, h_n^{(t-1)}\big]^\top\, W_a \;+\; b, $$其中 $A_{s,i:}$ 是拼接邻接 $[A^{(\text{in})}\,|\,A^{(\text{out})}]$ 的第 $i$ 行。然后用 GRU 风格的门做状态更新:
$$ \begin{aligned} z_t &\;=\; \sigma\!\big(W_z\, a_t + U_z\, h_{t-1}\big), \\ r_t &\;=\; \sigma\!\big(W_r\, a_t + U_r\, h_{t-1}\big), \\ \tilde h_t &\;=\; \tanh\!\big(W\, a_t + U\, (r_t \odot h_{t-1})\big), \\ h_t &\;=\; (1 - z_t)\, h_{t-1} \;+\; z_t\, \tilde h_t \, . \end{aligned} $$重置门 $r_t$ 决定旧状态有多少能进入候选状态;更新门 $z_t$ 决定旧状态和新候选怎么混合。直觉跟普通 GRU 完全一样,只是这里的"下一个输入"换成了图上的聚合消息 $a_t$,而不是序列里的下一个 token。

跑完 $T$ 步之后,每个节点都拿到一个"看过自己邻居、看过会话里所有环路结构、也看过每条转移触发频次"的向量 $h_v$。

实战中要注意的几点:
- 传播步数 $T$:原论文在 Yoochoose 和 Diginetica 上都只用 $T = 1$。把 $T$ 调大几乎没有提升——会话图本来就只有 10 个节点上下,一步传播信号就饱和了。
- 参数共享:GRU 单元的所有参数 $(W_z, U_z, W_r, U_r, W, U, W_a, b)$ 在所有节点、所有会话之间共享。整个模型里"会话相关"的部分只有物品表里被查到的那些行。
- 重复访问:因为节点去重,两次点击
B在传播过程中共享同一个嵌入。次序信息要等到下一步的池化阶段才显式恢复。
4. 会话表示的拼装
拿到节点表示 $\{h_1, \dots, h_n\}$ 之后,要压成一个会话向量 $s_h$。最朴素的做法——“直接用 $h_n$"——在短会话上表现意外地不差,但等于把 GGNN 学到的东西全扔掉。论文的设计是两路视图线性融合。
局部意图。 直接用最后一个点击的嵌入:
$$ s_l \;=\; h_n \, . $$这是用户"此刻在想什么"的最强单一信号。
全局上下文。 在所有节点表示上做一次软注意力,注意力的查询锚定在 $h_n$ 上:
$$ \alpha_i \;=\; q^\top\, \sigma\!\big(W_1\, h_n \;+\; W_2\, h_i \;+\; c\big), \qquad s_g \;=\; \sum_{i=1}^{n} \alpha_i\, h_i \, . $$这里 $q \in \mathbb{R}^d$ 是一个可学习的查询向量,$W_1, W_2 \in \mathbb{R}^{d\times d}$ 把"最后一次点击"和"会话里的每个物品"投到同一个打分空间里。和最后一次点击越相关的物品,权重越大;被新意图盖过去的物品,权重越小。
最终会话向量。 拼接后线性投影:
$$ s_h \;=\; W_3\, [\,s_l \,;\, s_g\,], \qquad W_3 \in \mathbb{R}^{d \times 2d} \, . $$
局部+全局的拆分思路在 NARM、STAMP 里也能看到,但 SR-GNN 多了一层关键差别:它进入全局求和的 $h_i$ 已经是结构感知的图嵌入,不是原始物品嵌入。这一步差距贡献了大部分实际收益。
5. 打分与训练
有了会话向量 $s_h$ 和物品表 $V \in \mathbb{R}^{|V|\times d}$,候选物品的打分就是内积:
$$ \hat z_i \;=\; s_h^\top\, v_i, \qquad \hat y \;=\; \mathrm{softmax}(\hat z) \, . $$训练目标是对真实下一跳的交叉熵:
$$ \mathcal{L} \;=\; -\sum_{i=1}^{|V|} y_i \log \hat y_i \, . $$几个真正会影响实现的细节:
- BPTT 但是图很小:梯度会沿着 $T$ 步 GGNN 反传。考虑到 $T$ 通常就是 1、节点数最多十几个,反传的代价跟"序列模型在几百个 token 上跑 BPTT"完全不是一个量级。
- 优化器:Adam,$\eta = 10^{-3}$,$(\beta_1, \beta_2) = (0.9, 0.999)$,所有矩阵 L2 衰减 $10^{-5}$。
- 嵌入维度:$d = 100$ 是默认值。在 Yoochoose 1/64 和 Diginetica 上把它放大到 256 或 512 只会过拟合,Recall@20 不会涨。
- 批处理:会话长度不齐,实现里需要把每个 batch 补齐到最大长度并加 mask。官方仓库 https://github.com/CRIPAC-DIG/SR-GNN/tree/master 把这一段处理得很干净,自己重写时建议直接抄它的 mask 逻辑。
- 超大物品池:当 $|V| > 10^5$ 时,全量 softmax 会变成瓶颈。把它换成 sampled softmax 或两塔召回头,SR-GNN 主体一行都不用动。
6. 图视角凭什么压过序列基线
纯序列的写法是 $h_t = \mathrm{GRU}(h_{t-1}, v_t)$。它有三个内生缺陷,恰好被图视角逐一规避:
- 回访就丢转移。 当用户点击
A -> B -> C -> B,序列模型在第二次到达B时会用新隐状态覆盖旧的,关于B -> C的信息悄悄消失。会话图把 $B \to C$ 和 $C \to B$ 同时显式保留,GGNN 把两条边都当作节点B的邻接结构来看。 - 关系结构得靠梯度自己学。 序列模型必须凭梯度信号"悟到"“不同位置上的同一物品其实是同一回事”;图的邻接矩阵则把这一事实直接塞进了模型架构。
- 信息只能单向流动。 RNN 是从左到右的;图通过入/出邻接拆分让信息双向传播,
D想从前驱B拉信息不需要等到反向一遍才能拿到。
这些差距在基准上是看得见的。SR-GNN 在标准评测——Yoochoose 1/64、Yoochoose 1/4、Diginetica 上的 Recall@20 和 MRR@20 都全面压过 POP / Item-KNN / FPMC / GRU4Rec / NARM / STAMP:

提升幅度最大的是 Diginetica——它的会话普遍更长、回访也更多,正是序列模型丢转移最严重的场景。
7. 超参数与训练配方
论文给出的默认值在 SR-GNN 这类任务里调得相当扎实,迁移到新数据集时几乎不用改:
| 超参 | 取值 | 说明 |
|---|---|---|
| 嵌入维度 $d$ | 100 | 64–128 是甜区;256+ 在常见 SBR 数据上必过拟合 |
| GGNN 传播步数 $T$ | 1 | 在 Diginetica 上 $T=2$ 略好,Yoochoose 上反而下降 |
| 优化器 | Adam | $\eta = 10^{-3}$,$(\beta_1, \beta_2) = (0.9, 0.999)$ |
| 学习率衰减 | 每 3 轮 ×0.1 | 第 3 轮起开始衰减,Recall@20 通常能再涨 0.5–1.0 个点 |
| Batch size | 100 | 50–200 都可以,敏感度低 |
| L2 权重衰减 | $10^{-5}$ | 对所有 $W_*$ 与物品表都加 |
| Dropout | GGNN 内不加;物品表 0.5 | 物品表 dropout 主要在评估时正则化长尾 |
| Early stopping | 监控 Recall@20,patience 5 | 多数实验在 8–12 轮收敛 |
一个容易踩的坑:官方预处理会过滤长度为 1 的会话和过滤出现次数 < 5 的物品。如果你直接拿论文数字对比、却没有这两步过滤,Recall@20 会比公布值低 3–5 个点,然后浪费一周去调模型而不是数据。
8. 失败模式与修法
SR-GNN 不是 SBR 的万能药。下面四类问题在生产里几乎一定会遇到,应该当成 checklist。
8.1 流行度坍塌
症状:Recall@20 看起来还行,但 top-$K$ 列表来回就那 5–10 个全站爆款,会话之间几乎没差别。多样性指标(intra-list distance、coverage@K)惨不忍睹。
原因:交叉熵 + 全量 softmax 天然偏向流行物品——它们出现在大多数正样本里,模型学到的最低损失策略就是"哪个流行就推哪个”。
修法:
- 给打分加流行度惩罚:$\hat z_i \mathrel{-}= \lambda \log \mathrm{freq}(i)$,$\lambda$ 在 $[0.1, 1.0]$ 之间按"召回率/多样性"的取舍调。
- 用逆倾向加权的 softmax,把训练时流行正样本的权重压低。
- 按流行度分布采样负样本,强迫模型去区分自己默认会推的爆款。
8.2 极短会话($n \le 3$)效果差
症状:长度 ≥ 5 的会话表现优秀,但长度 2–3 的会话被一个简单的共点击基线吊打。
原因:长度 2 的会话图就 2 个节点 1 条边,根本没有图结构可言;池化也退化成"用 $h_n$"。
修法:
- 混合服务:长度 ≤ 3 的会话路由到 item-KNN 或共点击模型,只有长度 ≥ 4 的会话才走 SR-GNN。在每个长度桶上一般都比单一模型强。
- 图增强:把最后一次点击与全站共点击图里它的 top-$k$ 邻居挂上边,给短会话"借"一些结构。
- 预训练物品嵌入:先在全站共点击图上跑 DeepWalk / node2vec,再把结果作为 SR-GNN 物品表的初始化。短会话从一开始就拿到信息丰富的嵌入,而不是随机噪声。
8.3 小数据集上过拟合
症状:训练 Recall@20 一路涨;验证 Recall@20 在第 4 轮就停滞,之后开始下滑。
原因:物品表 $V \in \mathbb{R}^{|V|\times d}$ 是参数量最大的一块,在小数据集上很容易把长尾物品的身份"背下来”。
修法:
- 把 $d$ 从 100 降到 50。
- 训练时给物品表加 dropout 0.5–0.7,同一会话内 mask 一致。
- L2 权重衰减加到 $10^{-4}$。
- early stopping patience 从 5 降到 3。
8.4 冷启动物品
症状:训练时出现次数少于 5 次的物品几乎从不被推荐,无论会话向量长什么样,它们的内积都很小。
原因:这些物品在 $V$ 里的对应行几乎拿不到梯度,一直停留在初始化附近。
修法:
- 接入内容特征(标题文本嵌入、品类、品牌等)作为旁路,令 $v_i = V[i] + g(\text{content}_i)$。冷启动物品的行就能从 $g$ 那里继承到先验。
- 用两塔召回做候选生成,把 SR-GNN 留给精排。
9. 几个值得知道的变体
下面这些扩展是后续 SBR 论文和工业系统反复出现的"标配套件":
9.1 注意力化的消息传递
把固定的 $A_s$ 换成按边计算的注意力权重:
$$ \alpha_{ij} \;=\; \mathrm{softmax}_j\!\big(\mathrm{LeakyReLU}(a^\top [W h_i \,||\, W h_j])\big), \qquad a_t^{(i)} \;=\; \sum_{j \in \mathcal N(i)} \alpha_{ij}\, W h_j \, . $$本质就是把 GAT 塞进 GGNN 单元里。当不同转移的"重要程度"差异较大时有用。
9.2 时间间隔作为边特征
会话长度从几秒到几分钟都有;2 秒内的连点和间隔 5 分钟的两次点击信号强度完全不一样。把时间间隔 $\Delta t_{u \to v}$ 编进边权:
$$ w_{u \to v} \;=\; \exp\!\big(-\beta \cdot \Delta t_{u \to v}\big) \cdot \frac{\#(u \to v)}{\mathrm{outdeg}(u)} \, . $$让 $\beta$ 可学习,最终值通常会落在 $0.05$–$0.2$ 之间($\Delta t$ 以秒为单位时)。
9.3 多任务头
在同一个 $s_h$ 上加辅助损失:
- 预测会话长度(对 $\log n$ 做回归)。
- 预测用户是否会在 24 小时内回来(二分类)。
- 预测下一次点击的品类。
辅助任务能给会话向量加一层正则化,在主任务损失噪声较大时有帮助。权重保持小($0.05$–$0.2$),它们是辅助信号,不是主目标。
10. 什么场景该用 SR-GNN,什么场景换别的
| 场景 | 建议 |
|---|---|
| 长会话且有回访($n \ge 5$) | SR-GNN,最舒服的甜区 |
| 极短会话($n \le 3$) | item-KNN 或共点击,SR-GNN 没图可用 |
| 重度冷启动 | 两塔 + 内容特征召回,SR-GNN 仅用于精排 |
| 实时延迟预算 < 5 ms | 缓存物品邻居表示,或蒸馏成 MLP 头 |
| 物品池 $ | V |
| 有强长期用户历史 | 看 SASRec、BERT4Rec 这一支带用户嵌入的序列模型 |
五点小结
- 会话即图。点击流变成有向加权图,回访和环路被原样保留在邻接里,而不是被隐状态盖掉。
- GGNN = 图消息上的 GRU。先在入/出邻接上聚合一个消息 $a_t$,再走重置门、更新门、候选状态。一步传播一般就够。
- 局部 + 全局池化。会话向量由"最后一次点击"(短期意图)和"以最后点击为查询的注意力求和"(全局上下文)线性融合而成。
- 交叉熵训练 + 内积打分。整套训练范式毫无花活——真正赢的是"嵌入里编码了图结构"这件事。
- 甜区是中等长度且带回访的会话。出了这个范围——长度 ≤ 3、冷启动、超大物品池——别在 SR-GNN 内部硬调,换成合适的搭档(KNN、内容塔、采样 softmax)才是正解。
更进一步的洞察是结构上的:会话推荐根本不是"披着推荐外衣的序列问题",而是"带序列先验的图问题"。一旦接受了这个视角,这条线后面所有进展——注意力化的 GNN(GC-SAN)、双曲嵌入(HCGR)、LLM 增强的会话模型(LLMGR)——读起来都会顺很多。