SR-GNN —— Session-based Recommendation with Graph Neural Networks

SR-GNN 把一段点击会话拆成有向加权图,再用门控 GNN 做下一跳预测。本文系统讲清会话图构建、GGNN 更新、局部+全局池化、训练细节、基准对比,以及决定要不要在生产里用它的几类失败模式。

用户依次点击 A、B、C、B、D。把它喂给一个序列模型,得到的是五个 token 折叠出的隐状态;交给 SR-GNN,得到的是一张有向图——边 B -> C 即便用户回到 B 也仍然存在,节点 B 只出现一次(它的入边和出邻居都会贡献到它的表示上),整段点击的拓扑都被原样保留在邻接矩阵里。这就是 SR-GNN(Wu 等, AAAI 2019) 在多个会话推荐基准上稳稳压过 GRU4Rec、NARM 等纯序列模型的根本原因。

这篇笔记把 SR-GNN 从头到尾拆开讲:会话图怎么建、门控 GNN(GGNN)在图上做了什么、会话向量为什么要做"局部+全局"两路融合、打分和训练有哪些坑,以及它在哪些情况下会崩——目的不是说服你用它,而是让你能准确判断什么时候不要用它。

你会学到什么

  • 怎么从点击流构造一张有向加权会话图,以及入/出邻接矩阵怎么喂给 GGNN
  • 门控 GNN 更新可以拆成"先聚合消息、再走 GRU 单元"两步
  • 会话池化为什么要做局部+全局,以及注意力为什么要锚定在最后一个点击上
  • 训练目标、几个真正合理的超参数,以及在这种小图上 BPTT 究竟意味着什么
  • 在带回访、长度中等的会话上,图视角为什么稳赢序列模型
  • 几类典型失败模式(短会话、流行度坍塌、冷启动物品、超大物品池)和对应的修法
  • 几个值得知道的常见变体(attention 化的 GGNN、时间间隔边、多任务头),以及它们在什么场景下真正有用

前置准备

  • 熟悉消息传递与 GNN 基础术语(邻接矩阵、传播步数)
  • 了解 GRU/LSTM 的门控结构
  • 会用推荐评测指标 Recall@K、MRR@K,知道 sampled softmax 是干什么的

1. 背景:会话推荐到底特殊在哪

会话推荐的限制条件比常见 CF 苛刻得多:没有稳定的长期用户画像可用,能看到的就是当前这一段短点击序列 $s = [v_{s,1}, v_{s,2}, \dots, v_{s,n}]$,物品池为 $V = \{v_1, \dots, v_{|V|}\}$,目标是预测下一跳 $v_{s,n+1}$。模型最终输出在 $|V|$ 上的打分向量 $\hat z$,取 top-$K$ 推荐。

这种设定让传统 CF 和纯 RNN 都不太顺手,原因有两点:

  • 上下文极短:典型会话只有 2–10 个点击,模型必须从极少的信号里反推用户当下的"意图"。
  • 物品会重复出现,意图也不单调:用户会反复比较,绕回来再点同一个商品。同一物品在一段会话里出现多次是常态,“靠后的点击"也未必比"靠前的点击"代表更稳定的偏好。

纯序列模型把这些点击压成一个隐状态,就不可避免地丢掉了"重复访问的物品之间还有结构信息"这一事实。SR-GNN 的贡献,就是把这种结构原原本本以图的形式保留下来,剩下交给消息传递。

2. 会话图的构建

对每段会话 $s$,SR-GNN 构造有向图 $G_s = (V_s, E_s)$,节点集 $V_s$ 是这段会话里去重后的物品(不论某物品被点击几次,作为节点只出现一次),边集 $E_s$ 包含每一次观察到的转移 $u \to v$。边权按起点的出度归一化:

$$ w_{u \to v} \;=\; \frac{\#(u \to v)}{\mathrm{outdeg}(u)} \, . $$

由此自然得到两张邻接矩阵:入度邻接 $A^{(\text{in})}$(第 $i$ 行告诉节点 $i$ 谁会喂信息给它)和出度邻接 $A^{(\text{out})}$(第 $i$ 行告诉节点 $i$ 它会喂信息给谁)。两张矩阵横着拼成 $|V_s| \times 2|V_s|$ 的 $A_s$,作为 GGNN 真正消费的输入。

会话图构建:点击流 A,B,C,B,D 对应的有向加权图;即使用户回到 B,所有跳转关系也都被保留

具体看点击流 A, B, C, B, D 的边权计算:

出现次数起点出度边权
$A \to B$111.00
$B \to C$120.50
$C \to B$111.00
$B \to D$120.50

如果直接用 GRU 跑这串序列,第二次访问 B 时新的隐状态会把”B -> C“这条转移悄悄覆盖掉;图的视角则一字不漏地保留下来。

3. 门控 GNN 的传播

每个物品 $v$ 都有一个 $d$ 维向量 $h_v$,从全局物品表 $V \in \mathbb{R}^{|V|\times d}$ 里查出来。SR-GNN 在会话图上跑 $T$ 轮消息传递,用的是 门控图神经网络 GGNN(Li 等, 2016)——本质上就是一个以"聚合后的消息"为输入的 GRU 单元。

第 $t$ 步,节点 $i$ 收到的聚合消息是

$$ a_t^{(i)} \;=\; A_{s,i:}\, \big[h_1^{(t-1)}, \dots, h_n^{(t-1)}\big]^\top\, W_a \;+\; b, $$

其中 $A_{s,i:}$ 是拼接邻接 $[A^{(\text{in})}\,|\,A^{(\text{out})}]$ 的第 $i$ 行。然后用 GRU 风格的门做状态更新:

$$ \begin{aligned} z_t &\;=\; \sigma\!\big(W_z\, a_t + U_z\, h_{t-1}\big), \\ r_t &\;=\; \sigma\!\big(W_r\, a_t + U_r\, h_{t-1}\big), \\ \tilde h_t &\;=\; \tanh\!\big(W\, a_t + U\, (r_t \odot h_{t-1})\big), \\ h_t &\;=\; (1 - z_t)\, h_{t-1} \;+\; z_t\, \tilde h_t \, . \end{aligned} $$

重置门 $r_t$ 决定旧状态有多少能进入候选状态;更新门 $z_t$ 决定旧状态和新候选怎么混合。直觉跟普通 GRU 完全一样,只是这里的"下一个输入"换成了图上的聚合消息 $a_t$,而不是序列里的下一个 token。

SR-GNN 端到端流程:会话图 -> GGNN 传播 -> 节点表示 -> 注意力池化 -> 物品池 softmax

跑完 $T$ 步之后,每个节点都拿到一个"看过自己邻居、看过会话里所有环路结构、也看过每条转移触发频次"的向量 $h_v$。

门控 GNN 单步更新:入/出邻接行先聚合出消息 $a_t$,再驱动一个 GRU 风格的状态更新

实战中要注意的几点:

  • 传播步数 $T$:原论文在 Yoochoose 和 Diginetica 上都只用 $T = 1$。把 $T$ 调大几乎没有提升——会话图本来就只有 10 个节点上下,一步传播信号就饱和了。
  • 参数共享:GRU 单元的所有参数 $(W_z, U_z, W_r, U_r, W, U, W_a, b)$ 在所有节点、所有会话之间共享。整个模型里"会话相关"的部分只有物品表里被查到的那些行。
  • 重复访问:因为节点去重,两次点击 B 在传播过程中共享同一个嵌入。次序信息要等到下一步的池化阶段才显式恢复。

4. 会话表示的拼装

拿到节点表示 $\{h_1, \dots, h_n\}$ 之后,要压成一个会话向量 $s_h$。最朴素的做法——“直接用 $h_n$"——在短会话上表现意外地不差,但等于把 GGNN 学到的东西全扔掉。论文的设计是两路视图线性融合。

局部意图。 直接用最后一个点击的嵌入:

$$ s_l \;=\; h_n \, . $$

这是用户"此刻在想什么"的最强单一信号。

全局上下文。 在所有节点表示上做一次软注意力,注意力的查询锚定在 $h_n$ 上

$$ \alpha_i \;=\; q^\top\, \sigma\!\big(W_1\, h_n \;+\; W_2\, h_i \;+\; c\big), \qquad s_g \;=\; \sum_{i=1}^{n} \alpha_i\, h_i \, . $$

这里 $q \in \mathbb{R}^d$ 是一个可学习的查询向量,$W_1, W_2 \in \mathbb{R}^{d\times d}$ 把"最后一次点击"和"会话里的每个物品"投到同一个打分空间里。和最后一次点击越相关的物品,权重越大;被新意图盖过去的物品,权重越小。

最终会话向量。 拼接后线性投影:

$$ s_h \;=\; W_3\, [\,s_l \,;\, s_g\,], \qquad W_3 \in \mathbb{R}^{d \times 2d} \, . $$

会话池化:注意力权重锚定在最后一次点击上;局部意图 + 全局上下文线性融合成会话向量

局部+全局的拆分思路在 NARM、STAMP 里也能看到,但 SR-GNN 多了一层关键差别:它进入全局求和的 $h_i$ 已经是结构感知的图嵌入,不是原始物品嵌入。这一步差距贡献了大部分实际收益。

5. 打分与训练

有了会话向量 $s_h$ 和物品表 $V \in \mathbb{R}^{|V|\times d}$,候选物品的打分就是内积:

$$ \hat z_i \;=\; s_h^\top\, v_i, \qquad \hat y \;=\; \mathrm{softmax}(\hat z) \, . $$

训练目标是对真实下一跳的交叉熵:

$$ \mathcal{L} \;=\; -\sum_{i=1}^{|V|} y_i \log \hat y_i \, . $$

几个真正会影响实现的细节:

  • BPTT 但是图很小:梯度会沿着 $T$ 步 GGNN 反传。考虑到 $T$ 通常就是 1、节点数最多十几个,反传的代价跟"序列模型在几百个 token 上跑 BPTT"完全不是一个量级。
  • 优化器:Adam,$\eta = 10^{-3}$,$(\beta_1, \beta_2) = (0.9, 0.999)$,所有矩阵 L2 衰减 $10^{-5}$。
  • 嵌入维度:$d = 100$ 是默认值。在 Yoochoose 1/64 和 Diginetica 上把它放大到 256 或 512 只会过拟合,Recall@20 不会涨。
  • 批处理:会话长度不齐,实现里需要把每个 batch 补齐到最大长度并加 mask。官方仓库 https://github.com/CRIPAC-DIG/SR-GNN/tree/master 把这一段处理得很干净,自己重写时建议直接抄它的 mask 逻辑。
  • 超大物品池:当 $|V| > 10^5$ 时,全量 softmax 会变成瓶颈。把它换成 sampled softmax 或两塔召回头,SR-GNN 主体一行都不用动。

6. 图视角凭什么压过序列基线

纯序列的写法是 $h_t = \mathrm{GRU}(h_{t-1}, v_t)$。它有三个内生缺陷,恰好被图视角逐一规避:

  • 回访就丢转移。 当用户点击 A -> B -> C -> B,序列模型在第二次到达 B 时会用新隐状态覆盖旧的,关于 B -> C 的信息悄悄消失。会话图把 $B \to C$ 和 $C \to B$ 同时显式保留,GGNN 把两条边都当作节点 B 的邻接结构来看。
  • 关系结构得靠梯度自己学。 序列模型必须凭梯度信号"悟到"“不同位置上的同一物品其实是同一回事”;图的邻接矩阵则把这一事实直接塞进了模型架构。
  • 信息只能单向流动。 RNN 是从左到右的;图通过入/出邻接拆分让信息双向传播,D 想从前驱 B 拉信息不需要等到反向一遍才能拿到。

这些差距在基准上是看得见的。SR-GNN 在标准评测——Yoochoose 1/64、Yoochoose 1/4、Diginetica 上的 Recall@20 和 MRR@20 都全面压过 POP / Item-KNN / FPMC / GRU4Rec / NARM / STAMP:

基准对比:SR-GNN vs 序列推荐基线在 Yoochoose 1/64、Yoochoose 1/4、Diginetica 上的 Recall@20 与 MRR@20

提升幅度最大的是 Diginetica——它的会话普遍更长、回访也更多,正是序列模型丢转移最严重的场景。

7. 超参数与训练配方

论文给出的默认值在 SR-GNN 这类任务里调得相当扎实,迁移到新数据集时几乎不用改:

超参取值说明
嵌入维度 $d$10064–128 是甜区;256+ 在常见 SBR 数据上必过拟合
GGNN 传播步数 $T$1在 Diginetica 上 $T=2$ 略好,Yoochoose 上反而下降
优化器Adam$\eta = 10^{-3}$,$(\beta_1, \beta_2) = (0.9, 0.999)$
学习率衰减每 3 轮 ×0.1第 3 轮起开始衰减,Recall@20 通常能再涨 0.5–1.0 个点
Batch size10050–200 都可以,敏感度低
L2 权重衰减$10^{-5}$对所有 $W_*$ 与物品表都加
DropoutGGNN 内不加;物品表 0.5物品表 dropout 主要在评估时正则化长尾
Early stopping监控 Recall@20,patience 5多数实验在 8–12 轮收敛

一个容易踩的坑:官方预处理会过滤长度为 1 的会话过滤出现次数 < 5 的物品。如果你直接拿论文数字对比、却没有这两步过滤,Recall@20 会比公布值低 3–5 个点,然后浪费一周去调模型而不是数据。

8. 失败模式与修法

SR-GNN 不是 SBR 的万能药。下面四类问题在生产里几乎一定会遇到,应该当成 checklist。

8.1 流行度坍塌

症状:Recall@20 看起来还行,但 top-$K$ 列表来回就那 5–10 个全站爆款,会话之间几乎没差别。多样性指标(intra-list distance、coverage@K)惨不忍睹。

原因:交叉熵 + 全量 softmax 天然偏向流行物品——它们出现在大多数正样本里,模型学到的最低损失策略就是"哪个流行就推哪个”。

修法

  • 给打分加流行度惩罚:$\hat z_i \mathrel{-}= \lambda \log \mathrm{freq}(i)$,$\lambda$ 在 $[0.1, 1.0]$ 之间按"召回率/多样性"的取舍调。
  • 逆倾向加权的 softmax,把训练时流行正样本的权重压低。
  • 按流行度分布采样负样本,强迫模型去区分自己默认会推的爆款。

8.2 极短会话($n \le 3$)效果差

症状:长度 ≥ 5 的会话表现优秀,但长度 2–3 的会话被一个简单的共点击基线吊打。

原因:长度 2 的会话图就 2 个节点 1 条边,根本没有图结构可言;池化也退化成"用 $h_n$"。

修法

  • 混合服务:长度 ≤ 3 的会话路由到 item-KNN 或共点击模型,只有长度 ≥ 4 的会话才走 SR-GNN。在每个长度桶上一般都比单一模型强。
  • 图增强:把最后一次点击与全站共点击图里它的 top-$k$ 邻居挂上边,给短会话"借"一些结构。
  • 预训练物品嵌入:先在全站共点击图上跑 DeepWalk / node2vec,再把结果作为 SR-GNN 物品表的初始化。短会话从一开始就拿到信息丰富的嵌入,而不是随机噪声。

8.3 小数据集上过拟合

症状:训练 Recall@20 一路涨;验证 Recall@20 在第 4 轮就停滞,之后开始下滑。

原因:物品表 $V \in \mathbb{R}^{|V|\times d}$ 是参数量最大的一块,在小数据集上很容易把长尾物品的身份"背下来”。

修法

  • 把 $d$ 从 100 降到 50。
  • 训练时给物品表加 dropout 0.5–0.7,同一会话内 mask 一致。
  • L2 权重衰减加到 $10^{-4}$。
  • early stopping patience 从 5 降到 3。

8.4 冷启动物品

症状:训练时出现次数少于 5 次的物品几乎从不被推荐,无论会话向量长什么样,它们的内积都很小。

原因:这些物品在 $V$ 里的对应行几乎拿不到梯度,一直停留在初始化附近。

修法

  • 接入内容特征(标题文本嵌入、品类、品牌等)作为旁路,令 $v_i = V[i] + g(\text{content}_i)$。冷启动物品的行就能从 $g$ 那里继承到先验。
  • 两塔召回做候选生成,把 SR-GNN 留给精排。

9. 几个值得知道的变体

下面这些扩展是后续 SBR 论文和工业系统反复出现的"标配套件":

9.1 注意力化的消息传递

把固定的 $A_s$ 换成按边计算的注意力权重:

$$ \alpha_{ij} \;=\; \mathrm{softmax}_j\!\big(\mathrm{LeakyReLU}(a^\top [W h_i \,||\, W h_j])\big), \qquad a_t^{(i)} \;=\; \sum_{j \in \mathcal N(i)} \alpha_{ij}\, W h_j \, . $$

本质就是把 GAT 塞进 GGNN 单元里。当不同转移的"重要程度"差异较大时有用。

9.2 时间间隔作为边特征

会话长度从几秒到几分钟都有;2 秒内的连点和间隔 5 分钟的两次点击信号强度完全不一样。把时间间隔 $\Delta t_{u \to v}$ 编进边权:

$$ w_{u \to v} \;=\; \exp\!\big(-\beta \cdot \Delta t_{u \to v}\big) \cdot \frac{\#(u \to v)}{\mathrm{outdeg}(u)} \, . $$

让 $\beta$ 可学习,最终值通常会落在 $0.05$–$0.2$ 之间($\Delta t$ 以秒为单位时)。

9.3 多任务头

在同一个 $s_h$ 上加辅助损失:

  • 预测会话长度(对 $\log n$ 做回归)。
  • 预测用户是否会在 24 小时内回来(二分类)。
  • 预测下一次点击的品类

辅助任务能给会话向量加一层正则化,在主任务损失噪声较大时有帮助。权重保持小($0.05$–$0.2$),它们是辅助信号,不是主目标。

10. 什么场景该用 SR-GNN,什么场景换别的

场景建议
长会话且有回访($n \ge 5$)SR-GNN,最舒服的甜区
极短会话($n \le 3$)item-KNN 或共点击,SR-GNN 没图可用
重度冷启动两塔 + 内容特征召回,SR-GNN 仅用于精排
实时延迟预算 < 5 ms缓存物品邻居表示,或蒸馏成 MLP 头
物品池 $V
有强长期用户历史看 SASRec、BERT4Rec 这一支带用户嵌入的序列模型

五点小结

  • 会话即图。点击流变成有向加权图,回访和环路被原样保留在邻接里,而不是被隐状态盖掉。
  • GGNN = 图消息上的 GRU。先在入/出邻接上聚合一个消息 $a_t$,再走重置门、更新门、候选状态。一步传播一般就够。
  • 局部 + 全局池化。会话向量由"最后一次点击"(短期意图)和"以最后点击为查询的注意力求和"(全局上下文)线性融合而成。
  • 交叉熵训练 + 内积打分。整套训练范式毫无花活——真正赢的是"嵌入里编码了图结构"这件事。
  • 甜区是中等长度且带回访的会话。出了这个范围——长度 ≤ 3、冷启动、超大物品池——别在 SR-GNN 内部硬调,换成合适的搭档(KNN、内容塔、采样 softmax)才是正解。

更进一步的洞察是结构上的:会话推荐根本不是"披着推荐外衣的序列问题",而是"带序列先验的图问题"。一旦接受了这个视角,这条线后面所有进展——注意力化的 GNN(GC-SAN)、双曲嵌入(HCGR)、LLM 增强的会话模型(LLMGR)——读起来都会顺很多。

Liked this piece?

Follow on GitHub for the next one — usually one a week.

GitHub