论文笔记| 图神经网络的测试阶段训练(Test-Time Training, TTT)

论文标题:Test-Time Training for Graph Neural Networks
论文链接:https://arxiv.org/abs/2210.08813
论文来源:arxiv,2022
作者单位:Michigan State University the United States, Microsoft Research Asia China, Dalian University of Technology China, University of Notre Dame the United States

摘要

  图神经网络(GNN)在图分类任务中已取得了重大进展。然而,训练集和测试集的性能表现往往具有较大差异。为此,本文首次将test-time training的思路引入了GNN以增强模型的泛化性能。特别地,本文设计了一种新颖的基于自监督学习的test-time training策略根据测试图样本对模型进行调整。

背景与动机

  现有的图分类模型都是遵循训练-测试的范式,然而训练集和测试集的性能表现往往具有较大差异,如下图所示。


  作者解释其中的原因在于,图数据是高维空间中的非欧式结构数据,可能会有无穷并且多样的拓扑结构。如下图,来源于相同数据集的两个图数据样本具有相同的标签,但却具有极其不同的结构模式。显然,很难直接从左图将知识迁移到右图。


  然而如何对图神经网络所处理的图数据设计一种test-time training的框架具有挑战性。首先,如何针对图分类任务设计一种有效的自监督学习任务?图数据通常不仅仅包含节点属性信息,也包含丰富的结构信息。因此将这两种图信息融入到图自监督任务中极其重要。其次,如何缓解在test-time的潜在的特征失真?由于模型被一个单独测试样本所调整,表示空间可能朝着某个方向严重扭曲,这是在main task中所不希望见到的。

  为此本文提出了Graph Test-Time Training with Constraint(GT3),一个test-time training的框架来提高GNN模型在图分类任务上的表现。随后将对模型各部分进行介绍。

方法

整体框架

  假设所使用的图神经网络具有L层,第l层的可学习参数表示为\theta_{l}。总体的模型参数可以表示为\theta_{\text {main }}=\left(\theta_{1}, \theta_{2}, \ldots, \theta_{L}\right),其中\delta_{m}是在预测层的参数。本文将图分类任务视为主任务(main task)。对于所给的图的集合\mathcal{G}=\left(G_{1}, G_{2}, \ldots, G_{n}\right)和它对应的图的类标签\boldsymbol{y}=\left(y_{1}, y_{2}, \ldots, y_{n}\right),主任务的目标是:
\min _{\theta_{\text {main }}} \frac{1}{n} \sum_{i=1}^{n} \mathcal{L}_{m}\left(\mathrm{~A}_{i}, \mathrm{X}_{i}, y_{i} ; \theta_{\text {main }}\right), G_{i} \in \mathcal{G}
其中\mathcal{L}_{m}(\cdot)代表主任务的损失函数。

  此外,自监督学习任务(SSL)在test-time training的框架中也起到了重要的作用,一个所期望的SSL任务应该和主任务协同一致,确保无监督的test-time training能够带来分类性能的提升。本文提出了一个层级化的SSL任务来捕获潜在的节点-节点和节点-图的关联。SSL任务的目标记为\mathcal{L}_{s}(\cdot),整个GT3框架由三个阶段组成:(1)训练阶段;(2)test-time 训练阶段;(3)测试阶段。

训练阶段

  对于所给训练集\mathcal{G},训练阶段要根据主任务和SSL任务来对GNN模型中的参数进行学习,而主要的挑战在于如何融合这两个任务。作者首先利用两个独立的GNN来分别在主任务和SSL任务中训练,实验结果显示这两个GNN模型在前几层提取的特征极其相似。因此,本文在网络前几层使用了一个共享参数的GNN:\theta_{e}=\left(\theta_{1}, \theta_{2}, \ldots, \theta_{K}\right),其中K \in\{1,2, \ldots, L\}。共享的部分称为共享的图特征抽取器(shared graph feature extractor)。针对主任务部分的模型参数表示为\theta_{m}=\left(\theta_{K+1}, \ldots, \theta_{L}, \delta_{m}\right),称为图分类头(graph classification head)。相应的,SSL任务有对应的自监督学习头(self-supervised learning head),记为\theta_{s}=\left(\theta_{K+1}^{\prime}, \ldots, \theta_{L}^{\prime}, \delta_{S}\right),其中\delta_{S}是SSL任务的预测层。进一步总结,主任务的模型参数表示为\theta_{\text {main }}=\left(\theta_{e}, \theta_{m}\right),SSL任务的模型参数表示为\theta_{\text {self }}=\left(\theta_{e}, \theta_{s}\right)。总体框架的模型参数表示为\theta_{\text {overall }}=\left(\theta_{e}, \theta_{m}, \theta_{s}\right)。在训练阶段,整个框架的优化目标是最小化主任务和SSL任务损失的加权:
\min _{\theta_{e}, \theta_{m}, \theta_{s}} \frac{1}{n} \sum_{i=1}^{N} \mathcal{L}_{m}\left(\mathrm{~A}_{i}, \mathrm{X}_{i}, y_{i} ; \theta_{e}, \theta_{m}\right)+\gamma \mathcal{L}_{s}\left(\mathrm{~A}_{i}, \mathrm{X}_{i} ; \theta_{e}, \theta_{s}\right), G_{i} \in \mathcal{G}

Test-time Training阶段

  对于一个测试图样本G_{t},test-time training过程将利用SSL任务对之前学习到的模型进行微调。具体地,对于一个测试图样本G_{t}=\left(\mathrm{A}_{t}, \mathrm{X}_{t}\right),共享参数的图特征提取器(shared graph feature extractor)和SSL头(self-supervised learning head) 通过最小化SSL任务的损失来进行微调:
\min _{\theta_{e}, \theta_{s}} \mathcal{L}_{s}\left(\mathrm{~A}_{t}, \mathrm{X}_{t} ; \theta_{e}, \theta_{s}\right)

  假定\theta_{\text {overall }}^{*}=\left(\theta_{e}^{*}, \theta_{m}^{*}, \theta_{s}^{*}\right)是上一阶段(训练阶段)得到的模型最优参数。在本阶段(Test-time Training)阶段,对于所给的测试图样本G_{t},所提框架将\theta_{e}^{*}\theta_{s}^{*}更新为\theta_{e}^{*^{\prime}}\theta_{s}^{*^{\prime}}

测试阶段

  利用test-time training更新后的共享参数的图特征提取器和训练阶段得到的图分类头(即主任务模型)对测试样本G_{t}预测其类标签:
\hat{y}_{t}=f\left(\mathrm{X}_{t}, \mathrm{~A}_{t} ; \theta_{e}^{*^{\prime}}, \theta_{m}^{*}\right)

GT3中的自监督学习任务

  之前的文献指出一个合适且informative的SSL任务是test-time training成功的关键。然而为GNN设计合适的SSL任务具有挑战性。首先,图数据和图像数据本质不同,图像领域常用的旋转不变性(rotation invariance)在图领域不存在,这使得现有研究中常用的SSL任务在图数据上不可用。其次,现有的针对图分类任务所设计的SSL任务都是基于多个图样本设计的(如Infograph和MVGRL)。这里作者给出的解释我不是很理解:

  These are not applicable for the test-time training where we aim at specifically adapting the model for every single graph during the test time.

为什么每次只能用单个测试样本来进行Test-time Training?
  接上文,第三个原因是,图数据既包括节点属性信息也包括拓扑结构信息,在设计SSL任务中同时兼顾这两种图的信息很重要。

  受到图领域对比学习的启发,本文提出了一个层级化的SSL任务,包括一个局部的角度和一个全局的角度,来从节点-节点级别和节点-图级别充分利用图信息。此外,所提出的SSL任务并不是根据不同图之间的差别建立的,因此它能够应用到单一的图。后续实验验证了全局对比学习(节点-图级别)和局部对比学习(节点-节点级别)具有互补特性,共同提升所提方法的性能。

全局对比学习

  全局对比学习的目标是帮助节点表示捕获整个图的全局信息。本方法的思路是最大化节点表示和全局图表示的互信息。具体地,对于输入图样本G,利用数据增强生成两种不同的视角:1.原始图\text { View }_{0};2.随机shuffle图中所有节点特征属性的视角\text { View }_{1}。通过这两种图视角,我们可以根据共享的GNN特征提取器得到两个相应的节点特征Z_{0}Z_{1}。随后利用视角\text { View }_{0}通过一个多层感知机(MLP)来得到全局图表示g_{0}(这里为什么不对节点表示Z_{0}直接做池化,比如infograph中使用了meanpooling)。

  在正负样本的构建上,来源于原图\text { View }_{0}的节点表示和全局图表示被视为正样本,来源于扰动视角\text { View }_{1}的节点表示和全局图表示被视为负样本。利用一个判别器\mathcal{D}来计算每个节点表示-图表示对的相似度得分,正样本的得分应该更高,负样本的得分应更低。\mathcal{D}\left({Z}_{s i}, \mathbf{g}_{0}\right)=\operatorname{Sigmoid}\left({Z}_{s i} * \mathbf{g}_{0}\right),其中{Z}_{s i}代表从视角\text { View }_{s}得到的节点表示,*代表内积。全局对比学习的总目标函数定义为:
\mathcal{L}_{g}=-\frac{1}{2 N}\left(\sum_{i=1}^{N}\left(\log \mathcal{D}\left({Z}_{0 i}, \mathrm{~g}_{0}\right)+\log \left(1-\mathcal{D}\left({Z}_{1 i}, \mathrm{~g}_{0}\right)\right)\right)\right.

局部对比学习

  作者阐述了在全局对比学习之外,还需要局部对比学习的原因:

  In global contrastive learning, the model aims at capturing the global information of the whole graph into the node representations. In other words, the model attempts to learn an invariant graph representation from a global level and blend it into the node representations. However, graph data consists of nodes with various attributes and distinct structural roles. To fully exploit the structure information of a graph, in addition to the invariant graph representation from a global level, it is also important to learn invariant node representations from a local level.

  对于所给的一张输入图G,利用数据增强另外生成两种视角。采用的数据增强策略分别是:自适应的丢弃边自适应的属性掩盖。边丢弃的概率由边在图中重要性确定,越重要的边丢弃概率越小,节点掩盖的概率由节点属性每个维度的重要性确定,越重要的维度掩盖的概率越小。边的重要性通过相连节点的度的平均确定,对于节点属性的重要性:对于每个属性维度,首先计算该维度在整个图中所有节点上的规范化值(可能是归一化的),然后将其乘以每个节点的度,最后对所有节点的结果取平均值。

  将两种数据增强图输入到共享参数的GNN特征提取器,输出的两个节点表示矩阵分别记为Z_{2}Z_{3}。局部对比学习的目标是区分两个增强视图中的节点是否是输入图中的同一个节点。因此\left({Z}_{2 i}, {Z}_{3 i}\right)(i \in\{1, \ldots, N\})代表正样本对,N代表图中节点的数量。\left({Z}_{2 i}, {Z}_{3 j}\right)\left({Z}_{2 i}, {Z}_{2 j}\right)(i, j \in\{1, \ldots, N\} \text { and } i \neq j)分别代表intra-view的负样本对和inter-view的负样本对(GRACE图对比学习框架)。随后根据正负样本对建立InfoNCE对比学习损失:
{I}_{c}\left({Z}_{2 i}, {Z}_{3 i}\right)=\log \frac{h\left({Z}_{2 i}, {Z}_{3 i}\right)}{h\left({Z}_{2 i}, {Z}_{3 i}\right)+\sum_{j \neq i} h\left({Z}_{2 i}, {Z}_{3 j}\right)+\sum_{j \neq i} h\left({Z}_{2 i}, {Z}_{2 j}\right)}

其中h\left({Z}_{2 i}, {Z}_{3 j}\right)=e^{\cos \left(g\left({Z}_{2 i}\right), g\left({Z}_{3 j}\right)\right) / \tau}g()是两层感知机(投影头)增强模型的表达能力。经过此MLP之后得到的节点表示分别记为{Z}_{2}^{\prime}{Z}_{3}^{\prime}

  除上述对比学习目标以外,局部对比学习模块在总目标函数中额外融入了一个去相关正则化项(decorrelation regularizer),目的是:

  a decorrelation regularizer has also been added to the overall objective function of the local contrastive learning, in order to encourage different representation dimensions to capture distinct information. The decorrelation regularizer is applied over the refined node representations {Z}_{2}^{\prime} and {Z}_{3}^{\prime} as follows:
{I}_{d}\left({Z}_{2}^{\prime}\right)=\left\|{Z}_{2}^{\prime}{ }^{T} {Z}_{2}^{\prime}-I\right\|_{F}^{2}

这里没有看懂,作者没给出过多解释,也没有提供相应的参考文献。
  最终的局部对比学习损失定义如下:
{L}_{l}=-\frac{1}{2 N} \sum_{i=1}^{N}\left({I}_{c}\left({Z}_{2 i}, {Z}_{3 i}\right)+{I}_{c}\left({Z}_{3 i}, {Z}_{2 i}\right)\right)+\frac{\beta}{2}\left({I}_{d}\left({Z}_{2}^{\prime}\right)+{I}_{d}\left({Z}_{3}^{\prime}\right)\right)

对比学习总损失

GT3框架中SSL任务的总损失由全局对比学习损失和局部对比学习损失共同构成:
\mathcal{L}_{s}=\mathcal{L}_{g}+\alpha \mathcal{L}_{l}

适应性约束(Adaptation Constraint)

  之前的研究曾指出,直接利用test-time training可能会造成严重的表示失真,因为模型可能对于一个特定的测试样本的SSL任务过拟合,这会影响模型在主任务上的表现。为了缓解这种问题,本文在test-time training的过程中增加了一种适应性约束(adaptation constraint)。核心思路是将共享参数的图特征提取器(shared graph feature extractor)对训练样本和测试样本所生成的特征表示进行对齐。以这种方式,生成的测试图样本的特征表示被约束为和训练样本的特征表示尽可能相近。

  共享参数的图特征提取器由K层GNN组成,我们用\left\{\mathbf{H}_{1}^{K}, \mathbf{H}_{2}^{K}, \ldots, \mathbf{H}_{n}^{K}\right\}代表特征提取器对训练样本所输出的节点特征表示。在训练阶段完成后,我们先利用一个readout函数来获取全图的特征表示\mathbf{h}_{i}^{K}=\operatorname{READOUT}\left(\mathbf{H}_{i}^{K}\right),随后统计图特征表示的两个重要分布特征:均值(empirical mean)\mu=\frac{1}{n} \sum_{i}^{n} \mathbf{h}_{i}^{K}和协方差矩阵(covariance matrix)\Sigma=\frac{1}{n-1}\left(\mathbf{H}^{K^{T}} \mathbf{H}^{K}-\left(\mathbf{I}^{T} \mathbf{H}^{K}\right)^{T}\left(\mathbf{I}^{T} \mathbf{H}^{K}\right)\right),其中\mathbf{H}^{K}=\left\{\mathbf{h}_{1}^{K^{T}}, \ldots, \mathbf{h}_{n}^{K^{T}}\right\}。在test-time training阶段,对于一个输入的测试图G_{t}及其增强视角,我们同样统计这两个分布特征记为\mu_{t}\Sigma_{t}。自适应约束项迫使测试图样本的特征分布与训练样本的特征分布一致:
\mathcal{L}_{c}=\left\|\mu-\mu_{t}\right\|_{2}^{2}+\left\|\Sigma-\Sigma_{t}\right\|_{F}^{2}

理论分析(Theoretical Analysis)

  作者在理论上证明了test-time training对GNN在图分类任务上的表现是beneficial的,感兴趣的读者请自行查阅原论文。

实验

实验设置

  为了模拟测试数据和训练数据分布不一致的情况,数据集DD,ENZYMES,PROTEINS的图样本根据它们的图大小(节点数量)被分为了两组。在小规模图的组选取80%的图作为训练集,剩余的20%作为验证集。测试集由大规模图的组随机选取并保证数量与验证集一致。对于ogbg-molhiv采用官方的划分方式。以上数据划分方式被记为OOD数据划分,训练集、验证集、测试集的表现如下表所示:

  本文比较了三种方法:(1)RAW模型只在主任务上训练;(2)JOINT模型在主任务和所提的层级化SSL任务上联合训练;(3)GT3本文所提方法,模型额外通过test-time training训练。在OOD数据划分实验设置下的结果对比如下图所示:

GT3与RAW对比的性能增益如下表所示:


在随机数据划分的实验设置下(8:1:1)结果对比如下图所示:


消融实验

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容