Graph Neural Networks for Learning Equivariant Representations of Neural Networks

把神经网络本身画成一张图(神经元做节点、权重做边),再让 GNN 来读它,就能得到一个对隐藏单元置换天然等变的表示。换对了对称性,预测泛化、检索相似模型、跨架构合并权重这类任务才真正变得可学。

把一个 MLP 的隐藏单元换个顺序,函数本身不变,但参数向量却完全不同——这是「在网络空间里做学习」绕不开的第一道坎。如果不尊重这种置换对称性,下游模型就需要大量容量来记忆同一个函数的不同写法,从而影响泛化和迁移。 Kofinas 等人在 ICML 2024 的论文 Graph Neural Networks for Learning Equivariant Representations of Neural Networks 提出了一种简洁的解决方案:将网络视为有向图(神经元为节点,权重为边),并使用对节点置换等变的 GNN 来读取。下面依次介绍为什么需要等变、神经图怎么构造、等变的意义、模型搭建方法以及四类下游任务及其细节与坑。


你将学到什么#

  • 为什么「逐层隐藏单元置换」才是这件事真正的对称群
  • 怎样把 MLP / CNN / Transformer 都映射到同一种「神经图」上
  • 等变和不变的差别,以及在这个语境下各自该用在哪
  • 在神经图上做消息传递,PNA + FiLM 这套组合到底加了什么
  • 四个能直接吃到等变红利的下游任务:预测泛化、网络分类、相似模型检索、模型合并
  • 一些工程细节:探测特征、归一化方式、位置嵌入、扩展性

阅读门槛#

  • 对 GNN 的基本套路(消息传递、节点特征、图级池化)有概念
  • 熟悉 MLP / CNN / Transformer 的常规结构
  • 大致知道「不变(invariance)」「等变(equivariance)」是什么

为什么「在网络空间上学习」必须考虑等变#

置换等变性:MLP 的隐藏单元对称性

近几年出现了一类新任务,将一个完整的训练好的网络视为一个数据点:

  • 预测泛化能力:不跑验证集,只看权重就估计测试精度
  • 按行为给网络分类:识别它解的任务、使用的数据集和优化器(如 SGD vs Adam,ResNet vs VGG 等)
  • 相似模型检索:在 模型库 里按「函数相似度」找类似模型
  • 元学习:从一群训练好的模型中学习规律(如哪种结构泛化好)
  • 模型合并:将多个独立训练的模型权重合并
$$f(x;\,W_1, b_1, W_2, b_2) \;=\; f(x;\,P W_1, P b_1, W_2 P^\top, b_2),$$ $$\mathcal{S} \;=\; S_{n_1} \times S_{n_2} \times \cdots \times S_{n_L},$$

也就是说,「同一个函数」在参数空间里对应着指数级多的等价点。一个无视 $\mathcal{S}$ 的下游学习器要么得自己把所有等价类都见过一遍(不现实),要么只能祈祷训练分布刚好覆盖到(同样不现实)。

图 1 把这个对称画了出来。把三个隐藏单元随便换个顺序,再相应地把 $W_1$ 的行和 $W_2$ 的列也按同样的顺序换一下,函数 $f(x)$ 一字不差,可 vec(W_1, b_1, W_2, b_2)$\mathbb{R}^d$ 里却变到了另外一个完全不同的点上。

几种最直觉的做法为什么都不行#

大多数人首先想到的几种方法,每种都会遇到不同问题。

把权重展平成一个向量: 把所有参数拼成一个 $\theta \in \mathbb{R}^d$ 喂给一个 MLP。这种表示对置换毫无抵抗力——换一下隐藏单元,$\theta$ 就完全不同了。而且 $d$ 直接绑定到具体的宽度和层数,不同尺寸的网络连放进同一个空间都做不到。

用权重的统计量: 算每一层权重张量的均值、方差、直方图、各阶矩。这样确实对置换不变,但关系信息全丢了:哪条权重连接哪两个神经元这种核心结构被完全抹平,「权重分布相似但函数行为不同」的两个网络会被压到同一个点上。

把权重矩阵当图像让 CNN 卷。 CNN 在二维网格上是平移等变的,可隐藏单元的对称是置换而不是平移——$W$ 的行和列并不是规则排在网格上的, CNN 的归纳偏置在这里用错了地方。而且 CNN 仍然要求固定形状,跨宽度迁移依旧无解。

三种失败的原因相同:对称用错、不变量选错或拓扑选错。我们需要一种本身就具有对称性的表示。

神经图:把权重组装成一张带类型的图#

正确的做法是直接画一张有向图,让它的结构和网络的计算图一一对应。

MLP 怎么变成图#

MLP 权重与作为神经图的 MLP

设一个 MLP 的层宽是 $n_0, n_1, \ldots, n_L$ ,第 $\ell$ 层的权重是 $W_\ell \in \mathbb{R}^{n_\ell \times n_{\ell-1}}$ ,偏置是 $b_\ell \in \mathbb{R}^{n_\ell}$ 。神经图的定义是:

  • 节点:每个神经元一个,总数 $\sum_\ell n_\ell$
  • 节点特征 $V$ :该神经元的偏置;可选地拼上位置 / 类型嵌入(输入/隐藏/输出、层号、激活函数类型)。
  • :每条权重一条,方向从源神经元指向目标神经元。
  • 边特征 $E$ :该权重的标量值;可选地拼上边类型嵌入(前向/残差、卷积/线性等)。

图 2 把两种表示放在一起对照。左边是「张量视角」:$(W_1, b_1, W_2, b_2)$ 各自是几个独立矩阵和向量,喂给下游学习器最自然的方式就是 vec(...),但拓扑也就一并被丢掉了。右边是「图视角」:完全一样的参数被组织成一张图,$h_1$ 的偏置写在节点上,$x_1 \to h_1$ 的权重写在边上。作为一个结构化对象,这张图在重新给隐藏节点编号时是不变的——图的同一性本来就是「同构意义下」的,标签 h1, h2, h3 不过是临时贴的便条。

为什么这种表示是「对的对象」#

神经图天然带着我们要的对称。把所有 $N = \sum_\ell n_\ell$ 个节点重新排序,对应的对称群是 $S_N$ ,比 MLP 真正的逐层置换 $\mathcal{S}$ 要大得多。这看起来过头了,但其实正合适

  • $\mathcal{S}$$S_N$ 的子群(合法的置换只能在层内做);
  • 任何对 $S_N$ 等变的模型自动也对 $\mathcal{S}$ 等变;
  • 所以一个 $S_N$ -等变的模型可以同时处理多种架构,不用为每种宽度分开训练。

这是这篇论文最实用的卖点。之前那一类工作(DeepSet 路线的逐层置换网络)只对某一个固定的 $\mathcal{S}$ 等变,意味着模型一旦训好,宽度变了就不能用。换成神经图加 GNN,整个家族一次搞定。

把 CNN、 Transformer 也装进同一种语言#

只要再做一些小调整,就能把别的层类型一起塞进神经图:

  • 卷积层: 一个 $c_\text{out} \times c_\text{in} \times k \times k$ 的卷积核会在 $c_\text{in}$ 个源通道节点和 $c_\text{out}$ 个目标通道节点之间生成一片二部块,把 $k \times k$ 的空间卷积核展平作为多维边特征。为了让不同卷积核大小可以共存,所有核都先零填充到最大尺寸。
  • 展平加全连接头: 展平后的全连接被当成 $1 \times 1$ 卷积处理,于是和卷积层在图上长得一模一样。配合自适应池化,神经图的拓扑就和输入图像分辨率解耦了。
  • 归一化层: LayerNorm / BatchNorm 的缩放 $\gamma$ 和偏移 $\beta$ 写成对角边块(每个通道一条边,边特征 $= \gamma_i$ )外加每个输出节点一个偏置 $\beta_i$ ,对角结构原原本本保留下来。
  • 残差连接: $y = x + f(x)$ 就是从源节点到对应目标节点加一条特征为 $1$ 的边——把恒等矩阵显式画出来而已。
  • 多头注意力: $W_Q^h, W_K^h, W_V^h, W_O$ 各自变成连接「输入 / 单头中间 / 输出」三组节点的带类型边块。注意力计算本身没有参数,所以图里不显式建模它,由 GNN 来近似那部分行为。
  • 激活函数: 每层用的是 ReLU 还是 GELU 还是 SiLU,作为可学习的嵌入加到对应节点特征上。

要点是用一种统一的图语言把所有标准层都装下了,于是同一个 GNN 既能吃 MLP,也能吃 CNN、 ResNet、 Transformer,下游一行架构特化的代码都不用写。

等变性:形式定义与操作含义#

不变性与等变性:相同的对称性,不同的输出

把概念再钉死一点。对图上函数 $f$ 和节点标号置换 $\pi$

  • $f$ 不变(invariant)$f(\pi \cdot G) = f(G)$图级输出(一个网络一个标量)用这个。
  • $f$ 等变(equivariant)$f(\pi \cdot G) = \pi \cdot f(G)$节点级输出(每个神经元一个向量)用这个。
$$h_v^{(\ell+1)} \;=\; \mathrm{UPDATE}\!\left(\,h_v^{(\ell)},\;\bigoplus_{u \in \mathcal{N}(v)} \mathrm{MSG}(h_u^{(\ell)}, e_{uv})\,\right),$$

并用一个对置换不变的聚合算子 $\bigoplus$ (求和、均值、最大、注意力都行)。给节点重新编号和走 GNN 这两件事是可以交换次序的:输入图换标号 $\Rightarrow$ 节点嵌入按相同方式换标号,边结构没变 $\Rightarrow$ 任何只看节点嵌入集合的标量函数也都没变。

图 4 把这两件事的实际差别拍了出来。左边,做一次图级池化(求和 / 均值 / 注意力)把节点嵌入压成一个向量 $z_G$ ,无论原图有没有被置换,得到的 $z_G$ 都是同一个——这是「预测泛化」「分类任务」这类要图级标量的场合该用的。右边,节点嵌入矩阵 $Z(G)$ 跟着输入一起被置换——这是「跨网络对齐神经元」必须有的性质,模型合并和架构编辑都靠它。

一句口诀:等变是更强的性质,不变是最后一步才用上的。流水线最干净的写法是消息传递全程等变,只在真要给图级标量的时候做一次池化把它收成不变。

模型架构:为神经图量身改造的 GNN 与 Transformer#

论文给了两个骨干,都针对一个不太常见的事实做了改造——边特征是主角(毕竟权重才是参数)。

NG-GNN:带边更新和 FiLM 调制的 PNA#

$$e_{uv}^{(\ell+1)} \;=\; \phi^{(\ell)}_E\!\left(\,e_{uv}^{(\ell)},\, h_u^{(\ell)},\, h_v^{(\ell)}\,\right),$$ $$\mathrm{MSG}(h_u, e_{uv}) \;=\; (\gamma(e_{uv}) \odot h_u) + \beta(e_{uv}),$$

其中 $\gamma, \beta$ 是两个小 MLP。权重门控来自源神经元的消息——这恰好就是真实网络在前向推理时做的事。

NG-T:带关系注意力的 Transformer#

$$V_{uv} \;=\; (\gamma(e_{uv}) \odot V_u) + \beta(e_{uv}),$$

也就是说,从 $v$$u$ 的注意力会被两者之间的权重调制。还是 FiLM 那一招,只不过搬到注意力里。经验上 NG-T 在「小网络但密图」上更强, NG-GNN 在「大网络但稀疏图」上更省。

整体流水线#

端到端 GNN 管道处理神经网络参数

图 3 把五个阶段一字排开:训好的网络 → 神经图 → $L$ 层带边更新的等变消息传递 → 图级池化 → 一个小 MLP 输出头。等变性是在第三步写进架构的:之前都是数据,之后要么继续保持等变,要么有意做一次塌缩(最后那一次池化)。

一些「看着小但真有用」的工程细节#

下面这些点单看都不起眼,论文的消融实验里每一个都贡献了肉眼可见的提升。

探测特征(probe features)。 准备一组固定的探测输入,把每个待分析网络的中间激活值都跑出来,按神经元拼到对应节点的特征里。这把函数信息直接打进了表示——而且它本身就对置换等变(探测只跟神经元交互,不跟标号交互),并且对任何「保函数不变」的参数变换都自动保持。实际做法里探测输入是学出来的,可以反传梯度上去,结果就是在那些光看权重统计量不够用的任务上有显著提升。

尊重对称的归一化: 之前的参数空间方法常常按「神经元在训练集上的均值/方差」做归一化,这其实破坏了对称性——「跨网络的第 7 个神经元」根本不是一个有意义的概念,因为神经元本来就是可置换的。论文的处理是按做归一化:每一层一个权重均值、一个权重方差、一个偏置均值、一个偏置方差,这样统计量本身就是 $\mathcal{S}$ -不变的。

位置嵌入要保住置换对称: 每个节点学一个位置嵌入,但只绑定到层号,不绑定到层内编号。同一隐藏层的所有节点共享一个位置向量,层内置换对称就守住了。输入和输出节点单独享有不同的位置嵌入——因为换它们的顺序改变函数(它们是网络对外的可见接口)。

反向边: 在每条前向边的反方向再加一条边(带自己的类型嵌入),消息传递的带宽翻倍,「反传方向」的信息一层就能传过去而不是 $L$ 层。便宜,稳定收益。

下游任务:等变性能换到什么#

预测泛化能力#

仅从权重预测泛化性能

设定:拿一组训好的网络(每个都知道测试精度),训 GNN 直接从权重回归测试精度。

图 5 直观地体现了两条路线的差别:等变版的预测点紧贴 $y = x$ ,而把权重展平喂 MLP 的版本则是一团飘忽的散点。等变性让模型不必把样本效率浪费在「同一个函数的不同写法」上,可以专心去拟合真正和泛化相关的量——谱、层间对齐、 sharpness 代理之类。

按行为给网络分类#

同样的流水线,把头换成分类:预测某个网络是哪个数据集 / 哪个任务 / 哪个优化器训出来的。一个有意思的现象是,为某个分类任务学到的图嵌入可以迁移到别的分类任务——GNN 学到的并不是某个特定分类边界,而是一个通用的「网络的特征空间」。

相似模型检索#

把整个 模型库 用 GNN 嵌入一遍,然后用余弦相似度去检索。功能相似的网络(比如两个不同种子训出的 CIFAR-10 分类器)会被嵌到相近的位置,尽管它们在参数空间里的距离基本是随机的。这正是等变性该带来的:嵌入度量是由网络实际算什么诱导出来的,而不是由它任意的参数化方式决定。

通过神经元对齐做模型合并#

这里要用等变的节点级嵌入,不是池化后的图级向量。把两个网络的神经元按它们的节点嵌入距离做匹配(匈牙利或最优传输),按对齐结果合并权重。传统的「拿一组探测输入比对激活」做对齐,在这套框架里反倒成了一个特例:探测激活只是 GNN 消费的诸多节点特征之一。

和几种基线放在一起对比#

方法等变?跨架构?保留拓扑?
展平权重 + MLP否(维度随宽度变)
权重统计量是(仅不变)否(关系丢失)
在权重矩阵上做 CNN否(平移 $\ne$ 置换)部分部分
DeepSet 风格的逐层置换网络是(绑死一个 $\mathcal{S}$否(只能跑一种架构)部分
神经图 + GNN (本文)是($S_N$ -等变)

论文里的实验整体也对得上这张表:在每种架构上,神经图 + GNN 都能与该架构上专门的方法持平或更好,而且它是唯一一个能同时吃所有架构的方案。

局限与待解的问题#

  • 规模: 一个十亿参数级别模型的神经图边数也是十亿级。即便用稀疏 GNN 库,目前这套流程在百万参数量级很舒服,到十亿就吃力了。按层 / 按块做局部神经图是显然的下一步。
  • 架构覆盖: MLP / CNN / ResNet / Transformer 都有了,但任意计算图(混合专家、动态路由、递归结构)还在外面。
  • 探测设计: 探测输入虽然能学,但用什么样的探测(对抗、随机、分布内、分布外)对哪种下游最优,目前还基本是经验问题。
  • 非对称初始化下的行为: 整套故事预设参数遵循 $\mathcal{S}$ 的轨道结构。某些权重共享方案、结构化稀疏可能打破这一假设,需要专门建模。

几点要带走的#

  1. 真正要对抗的对称是逐层隐藏置换,不是「整个 $\theta$ 」。神经图把这个对称结构精准地编进了表示。
  2. GNN 自带置换等变,配上一张拓扑复刻网络结构的图,正确的归纳偏置就自动到位了——不需要任何手工设计的参数共享方案。
  3. 一个模型,多种架构: 由于 $\mathcal{S}$ 是 GNN 等变的更大群 $S_N$ 的子群,同一个训好的 GNN 就能消化所有兼容架构。
  4. 池化只放在最后一步: 整条消息传递保持节点级等变,需要图级标量时再做一次不变池化。
  5. 真正的红利在「把网络当数据点」的元任务上:预测泛化、检索相似模型、合并权重。这些任务以前要么忽略对称性、要么手写一套对称约束,现在都可以省下来。

参考文献#

读有所得?

GitHub 关注我 → 新文周更

GitHub