Graph Neural Networks for Learning Equivariant Representations of Neural Networks

把神经网络本身画成一张图(神经元做节点、权重做边),再让 GNN 来读它,就能得到一个对隐藏单元置换天然等变的表示。换对了对称性,预测泛化、检索相似模型、跨架构合并权重这类任务才真正变得可学。

把一个 MLP 的隐藏单元换个顺序,函数本身一点没变,可参数向量却换了一副面孔——这是「在网络空间里做学习」绕不开的第一道坎。如果表示方式不尊重这种置换对称性,下游模型就要花大量容量去记忆「同一个函数的不同写法」,泛化和迁移都谈不上。Kofinas 等人在 ICML 2024 的这篇 Graph Neural Networks for Learning Equivariant Representations of Neural Networks 给的解法非常干净:把网络本身当成一张有向图(神经元当节点、权重当边),再用一个本身就对节点置换等变的 GNN 去读它。下面按照「为什么需要等变」「神经图怎么构造」「等变到底意味着什么」「模型怎么搭」「四类下游任务」「细节与坑」的顺序展开。

你会学到什么

  • 为什么「逐层隐藏单元置换」才是这件事真正的对称群
  • 怎样把 MLP / CNN / Transformer 都映射到同一种「神经图」上
  • 等变和不变的差别,以及在这个语境下各自该用在哪
  • 在神经图上做消息传递,PNA + FiLM 这套组合到底加了什么
  • 四个能直接吃到等变红利的下游任务:预测泛化、网络分类、相似模型检索、模型合并
  • 一些工程细节:探测特征、归一化方式、位置嵌入、扩展性

阅读门槛

  • 对 GNN 的基本套路(消息传递、节点特征、图级池化)有概念
  • 熟悉 MLP / CNN / Transformer 的常规结构
  • 大致知道「不变(invariance)」「等变(equivariance)」是什么

为什么「在网络空间上学习」必须考虑等变

近几年逐渐冒出一类新任务,它们把一个完整的训练好的网络当成一个数据点:

  • 预测泛化能力:不跑验证集,只看权重就估计测试精度
  • 按行为给网络分类:识别它解的是什么任务、用了什么数据集、用了什么优化器(SGD vs Adam,ResNet vs VGG 等等)
  • 相似模型检索:在 model zoo 里按「函数相似度」找类似模型
  • 元学习:在一群训好的模型上学规律(比如哪种结构泛化好)
  • 模型合并:把多个独立训练的模型权重合到一起

这五件事都被同一个麻烦绊住:MLP 的参数空间上有一个很大的离散对称群,群中的每个元素都让函数保持不变。对单个隐藏层和置换矩阵 $P$,

$$ f(x;\,W_1, b_1, W_2, b_2) \;=\; f(x;\,P W_1, P b_1, W_2 P^\top, b_2), $$

每一隐藏层都独立地拥有这种自由度。设各层宽度为 $n_1, \ldots, n_L$,整个对称群就是直积

$$ \mathcal{S} \;=\; S_{n_1} \times S_{n_2} \times \cdots \times S_{n_L}, $$

也就是说,「同一个函数」在参数空间里对应着指数级多的等价点。一个无视 $\mathcal{S}$ 的下游学习器要么得自己把所有等价类都见过一遍(不现实),要么只能祈祷训练分布刚好覆盖到(同样不现实)。

置换等变性:MLP 的隐藏单元对称

图 1 把这个对称画了出来。把三个隐藏单元随便换个顺序,再相应地把 $W_1$ 的行和 $W_2$ 的列也按同样的顺序换一下,函数 $f(x)$ 一字不差,可 vec(W_1, b_1, W_2, b_2) 在 $\mathbb{R}^d$ 里却变到了另外一个完全不同的点上。

几种最直觉的做法为什么都不行

绝大多数人最先会想到的几种方法,每一种都恰好踩中一个不同的坑。

把权重展平成一个向量。 把所有参数拼成一个 $\theta \in \mathbb{R}^d$ 喂给一个 MLP。这种表示对置换毫无抵抗力——换一下隐藏单元,$\theta$ 就完全不同了。而且 $d$ 直接绑定到具体的宽度和层数,不同尺寸的网络连放进同一个空间都做不到。

用权重的统计量。 算每一层权重张量的均值、方差、直方图、各阶矩。这样确实对置换不变,但关系信息全丢了:哪条权重连接哪两个神经元这种核心结构被完全抹平,「权重分布相似但函数行为不同」的两个网络会被压到同一个点上。

把权重矩阵当图像让 CNN 卷。 CNN 在二维网格上是平移等变的,可隐藏单元的对称是置换而不是平移——$W$ 的行和列并不是规则排在网格上的,CNN 的归纳偏置在这里用错了地方。而且 CNN 仍然要求固定形状,跨宽度迁移依旧无解。

三种失败的本质都一样:对称用错了,或者不变量选错了,或者拓扑选错了。我们要的是一种本身就「长成对称模样」的表示。

神经图:把权重组装成一张带类型的图

正确的做法是直接画一张有向图,让它的结构和网络的计算图一一对应。

MLP 怎么变成图

设一个 MLP 的层宽是 $n_0, n_1, \ldots, n_L$,第 $\ell$ 层的权重是 $W_\ell \in \mathbb{R}^{n_\ell \times n_{\ell-1}}$,偏置是 $b_\ell \in \mathbb{R}^{n_\ell}$。神经图的定义是:

  • 节点:每个神经元一个,总数 $\sum_\ell n_\ell$。
  • 节点特征 $V$:该神经元的偏置;可选地拼上位置 / 类型嵌入(输入/隐藏/输出、层号、激活函数类型)。
  • :每条权重一条,方向从源神经元指向目标神经元。
  • 边特征 $E$:该权重的标量值;可选地拼上边类型嵌入(前向/残差、卷积/线性等)。

把 MLP 的参数张量翻译成神经图

图 2 把两种表示放在一起对照。左边是「张量视角」:$(W_1, b_1, W_2, b_2)$ 各自是几个独立矩阵和向量,喂给下游学习器最自然的方式就是 vec(...),但拓扑也就一并被丢掉了。右边是「图视角」:完全一样的参数被组织成一张图,$h_1$ 的偏置写在节点上,$x_1 \to h_1$ 的权重写在边上。作为一个结构化对象,这张图在重新给隐藏节点编号时是不变的——图的同一性本来就是「同构意义下」的,标签 h1, h2, h3 不过是临时贴的便条。

为什么这种表示是「对的对象」

神经图天然带着我们要的对称。把所有 $N = \sum_\ell n_\ell$ 个节点重新排序,对应的对称群是 $S_N$,比 MLP 真正的逐层置换 $\mathcal{S}$ 要大得多。这看起来过头了,但其实正合适

  • $\mathcal{S}$ 是 $S_N$ 的子群(合法的置换只能在层内做);
  • 任何对 $S_N$ 等变的模型自动也对 $\mathcal{S}$ 等变;
  • 所以一个 $S_N$-等变的模型可以同时处理多种架构,不用为每种宽度分开训练。

这是这篇论文最实用的卖点。之前那一类工作(DeepSet 路线的逐层置换网络)只对某一个固定的 $\mathcal{S}$ 等变,意味着模型一旦训好,宽度变了就不能用。换成神经图加 GNN,整个家族一次搞定。

把 CNN、Transformer 也装进同一种语言

只要再做一些小调整,就能把别的层类型一起塞进神经图:

  • 卷积层。 一个 $c_\text{out} \times c_\text{in} \times k \times k$ 的卷积核会在 $c_\text{in}$ 个源通道节点和 $c_\text{out}$ 个目标通道节点之间生成一片二部块,把 $k \times k$ 的空间卷积核展平作为多维边特征。为了让不同卷积核大小可以共存,所有核都先零填充到最大尺寸。
  • 展平加全连接头。 展平后的全连接被当成 $1 \times 1$ 卷积处理,于是和卷积层在图上长得一模一样。配合自适应池化,神经图的拓扑就和输入图像分辨率解耦了。
  • 归一化层。 LayerNorm / BatchNorm 的缩放 $\gamma$ 和偏移 $\beta$ 写成对角边块(每个通道一条边,边特征 $= \gamma_i$)外加每个输出节点一个偏置 $\beta_i$,对角结构原原本本保留下来。
  • 残差连接。 $y = x + f(x)$ 就是从源节点到对应目标节点加一条特征为 $1$ 的边——把恒等矩阵显式画出来而已。
  • 多头注意力。 $W_Q^h, W_K^h, W_V^h, W_O$ 各自变成连接「输入 / 单头中间 / 输出」三组节点的带类型边块。注意力计算本身没有参数,所以图里不显式建模它,由 GNN 来近似那部分行为。
  • 激活函数。 每层用的是 ReLU 还是 GELU 还是 SiLU,作为可学习的嵌入加到对应节点特征上。

要点是用一种统一的图语言把所有标准层都装下了,于是同一个 GNN 既能吃 MLP,也能吃 CNN、ResNet、Transformer,下游一行架构特化的代码都不用写。

等变性:形式定义与操作含义

把概念再钉死一点。对图上函数 $f$ 和节点标号置换 $\pi$:

  • $f$ 不变(invariant):$f(\pi \cdot G) = f(G)$。图级输出(一个网络一个标量)用这个。
  • $f$ 等变(equivariant):$f(\pi \cdot G) = \pi \cdot f(G)$。节点级输出(每个神经元一个向量)用这个。

标准的消息传递 GNN 天生等变,因为它对每个节点用同一个更新函数,

$$ h_v^{(\ell+1)} \;=\; \mathrm{UPDATE}\!\left(\,h_v^{(\ell)},\;\bigoplus_{u \in \mathcal{N}(v)} \mathrm{MSG}(h_u^{(\ell)}, e_{uv})\,\right), $$

并用一个对置换不变的聚合算子 $\bigoplus$(求和、均值、最大、注意力都行)。给节点重新编号和走 GNN 这两件事是可以交换次序的:输入图换标号 $\Rightarrow$ 节点嵌入按相同方式换标号,边结构没变 $\Rightarrow$ 任何只看节点嵌入集合的标量函数也都没变。

不变与等变:同一个对称,输出形式不同

图 4 把这两件事的实际差别拍了出来。左边,做一次图级池化(求和 / 均值 / 注意力)把节点嵌入压成一个向量 $z_G$,无论原图有没有被置换,得到的 $z_G$ 都是同一个——这是「预测泛化」「分类任务」这类要图级标量的场合该用的。右边,节点嵌入矩阵 $Z(G)$ 跟着输入一起被置换——这是「跨网络对齐神经元」必须有的性质,模型合并和架构编辑都靠它。

一句口诀:等变是更强的性质,不变是最后一步才用上的。流水线最干净的写法是消息传递全程等变,只在真要给图级标量的时候做一次池化把它收成不变。

模型架构:为神经图量身改造的 GNN 与 Transformer

论文给了两个骨干,都针对一个不太常见的事实做了改造——边特征是主角(毕竟权重才是参数)。

NG-GNN:带边更新和 FiLM 调制的 PNA

底座选的是 PNA(Principal Neighborhood Aggregation),原因是它支持边特征,并且并行用了好几种聚合算子(均值、最大、标准差、按节点度做缩放)。原版 PNA 不更新边,论文加了一个逐层的边 MLP

$$ e_{uv}^{(\ell+1)} \;=\; \phi^{(\ell)}_E\!\left(\,e_{uv}^{(\ell)},\, h_u^{(\ell)},\, h_v^{(\ell)}\,\right), $$

让边特征也随深度演化。为了把「权重(边)」和「神经元状态(节点)」之间的乘性交互表达充分,消息函数走 FiLM 调制:

$$ \mathrm{MSG}(h_u, e_{uv}) \;=\; (\gamma(e_{uv}) \odot h_u) + \beta(e_{uv}), $$

其中 $\gamma, \beta$ 是两个小 MLP。权重门控来自源神经元的消息——这恰好就是真实网络在前向推理时做的事。

NG-T:带关系注意力的 Transformer

Transformer 变体把神经图当全连接图看,用关系注意力:把边特征作为 value 矩阵的一个偏置项注入,

$$ V_{uv} \;=\; (\gamma(e_{uv}) \odot V_u) + \beta(e_{uv}), $$

也就是说,从 $v$ 到 $u$ 的注意力会被两者之间的权重调制。还是 FiLM 那一招,只不过搬到注意力里。经验上 NG-T 在「小网络但密图」上更强,NG-GNN 在「大网络但稀疏图」上更省。

整体流水线

GNN 处理神经网络参数的端到端流水线

图 3 把五个阶段一字排开:训好的网络 → 神经图 → $L$ 层带边更新的等变消息传递 → 图级池化 → 一个小 MLP 输出头。等变性是在第三步写进架构的:之前都是数据,之后要么继续保持等变,要么有意做一次塌缩(最后那一次池化)。

一些「看着小但真有用」的工程细节

下面这些点单看都不起眼,论文的消融实验里每一个都贡献了肉眼可见的提升。

探测特征(probe features)。 准备一组固定的探测输入,把每个待分析网络的中间激活值都跑出来,按神经元拼到对应节点的特征里。这把函数信息直接打进了表示——而且它本身就对置换等变(探测只跟神经元交互,不跟标号交互),并且对任何「保函数不变」的参数变换都自动保持。实际做法里探测输入是学出来的,可以反传梯度上去,结果就是在那些光看权重统计量不够用的任务上有显著提升。

尊重对称的归一化。 之前的参数空间方法常常按「神经元在训练集上的均值/方差」做归一化,这其实破坏了对称性——「跨网络的第 7 个神经元」根本不是一个有意义的概念,因为神经元本来就是可置换的。论文的处理是按做归一化:每一层一个权重均值、一个权重方差、一个偏置均值、一个偏置方差,这样统计量本身就是 $\mathcal{S}$-不变的。

位置嵌入要保住置换对称。 每个节点学一个位置嵌入,但只绑定到层号,不绑定到层内编号。同一隐藏层的所有节点共享一个位置向量,层内置换对称就守住了。输入和输出节点单独享有不同的位置嵌入——因为换它们的顺序改变函数(它们是网络对外的可见接口)。

反向边。 在每条前向边的反方向再加一条边(带自己的类型嵌入),消息传递的带宽翻倍,「反传方向」的信息一层就能传过去而不是 $L$ 层。便宜,稳定收益。

下游任务:等变性能换到什么

1. 预测泛化能力

设定:拿一组训好的网络(每个都知道测试精度),训 GNN 直接从权重回归测试精度。

只看权重就预测泛化

图 5 直观地体现了两条路线的差别:等变版的预测点紧贴 $y = x$,而把权重展平喂 MLP 的版本则是一团飘忽的散点。等变性让模型不必把样本效率浪费在「同一个函数的不同写法」上,可以专心去拟合真正和泛化相关的量——谱、层间对齐、sharpness 代理之类。

2. 按行为给网络分类

同样的流水线,把头换成分类:预测某个网络是哪个数据集 / 哪个任务 / 哪个优化器训出来的。一个有意思的现象是,为某个分类任务学到的图嵌入可以迁移到别的分类任务——GNN 学到的并不是某个特定分类边界,而是一个通用的「网络的特征空间」。

3. 相似模型检索

把整个 model zoo 用 GNN 嵌入一遍,然后用余弦相似度去检索。功能相似的网络(比如两个不同种子训出的 CIFAR-10 分类器)会被嵌到相近的位置,尽管它们在参数空间里的距离基本是随机的。这正是等变性该带来的:嵌入度量是由网络实际算什么诱导出来的,而不是由它任意的参数化方式决定。

4. 通过神经元对齐做模型合并

这里要用等变的节点级嵌入,不是池化后的图级向量。把两个网络的神经元按它们的节点嵌入距离做匹配(匈牙利或最优传输),按对齐结果合并权重。传统的「拿一组探测输入比对激活」做对齐,在这套框架里反倒成了一个特例:探测激活只是 GNN 消费的诸多节点特征之一。

和几种基线放在一起对比

方法等变?跨架构?保留拓扑?
展平权重 + MLP否(维度随宽度变)
权重统计量是(仅不变)否(关系丢失)
在权重矩阵上做 CNN否(平移 $\ne$ 置换)部分部分
DeepSet 风格的逐层置换网络是(绑死一个 $\mathcal{S}$)否(只能跑一种架构)部分
神经图 + GNN(本文)是($S_N$-等变)

论文里的实验整体也对得上这张表:在每种架构上,神经图 + GNN 都能与该架构上专门的方法持平或更好,而且它是唯一一个能同时吃所有架构的方案。

局限与待解的问题

  • 规模。 一个十亿参数级别模型的神经图边数也是十亿级。即便用稀疏 GNN 库,目前这套流程在百万参数量级很舒服,到十亿就吃力了。按层 / 按块做局部神经图是显然的下一步。
  • 架构覆盖。 MLP / CNN / ResNet / Transformer 都有了,但任意计算图(混合专家、动态路由、递归结构)还在外面。
  • 探测设计。 探测输入虽然能学,但用什么样的探测(对抗、随机、分布内、分布外)对哪种下游最优,目前还基本是经验问题。
  • 非对称初始化下的行为。 整套故事预设参数遵循 $\mathcal{S}$ 的���道结构。某些权重共享方案、结构化稀疏可能打破这一假设,需要专门建模。

几点要带走的

  1. 真正要对抗的对称是逐层隐藏置换,不是「整个 $\theta$」。神经图把这个对称结构精准地编进了表示。
  2. GNN 自带置换等变,配上一张拓扑复刻网络结构的图,正确的归纳偏置就自动到位了——不需要任何手工设计的参数共享方案。
  3. 一个模型,多种架构。 由于 $\mathcal{S}$ 是 GNN 等变的更大群 $S_N$ 的子群,同一个训好的 GNN 就能消化所有兼容架构。
  4. 池化只放在最后一步。 整条消息传递保持节点级等变,需要图级标量时再做一次不变池化。
  5. 真正的红利在「把网络当数据点」的元任务上:预测泛化、检索相似模型、合并权重。这些任务以前要么忽略对称性、要么手写一套对称约束,现在都可以省下来。

延伸阅读

Liked this piece?

Follow on GitHub for the next one — usually one a week.

GitHub