Li L, Song D, Ma R, et al. KNN-BERT: fine-tuning pre-trained models with KNN classifier[J]. arXiv preprint arXiv:2110.02523, 2021.
摘要导读
预训练模型被广泛应用于利用交叉熵损失优化的线性分类器来微调下游任务,可能会面临鲁棒性和稳定性问题。这些问题可以通过学习表示来改进,即在做出预测时去关注在同一个类簇中表示的相似性,不同类簇之间的差异性。在本文中,作者提出将KNN分类器运用到预训练模型的微调中。对于该KNN分类器,作者引入了一个有监督的动量对比学习框架来学习有监督的下游任务的聚类表示。在大规模数据集和小样本数据机上的文本分类实验和鲁棒性测试都显示了结合将KNN结合到传统的微调过程中会得到很大的提升。
模型浅析
本文中提出了KNN-BERT,利用KNN分类器时,使用以BERT为代表的预训练模型作为文本表示编码器。下面将从KNN分类器的效用和如何为KNN分类器设计文本表示的训练过程两个方面进行介绍。
- KNN分类器
作者将一般的线性分类器与KNN分类器相结合,并使用加权平均logits来作为最终的预测logits。假设编码后的文本表示为,其对应的标签为
,线性分类器
;这里使用由
标记的样本
来代表由余弦相似度选出的
个近邻样本。
KNN对应的logits是一个投票结果,记为KNN。给定权重比重
,最终的得分
可以由如下的形式计算:
是由传统的交叉熵损失进行驱动。KNN的驱动方式将在下面的章节中给出其对应的对比学习框架。
- 用于KNN的对比学习
为了在预训练模型的微调中学习适用于KNN的表示,作者引入了一个监督型对比学习框架,该框架使用标签信息来构建对比学习的正负例样本。类似于info-nce损失,带有监督信息的对比损失定义为如下的形式:
表示包含与给定
具有相同标签的
个样本的集合,
则表示来自不同的样本的集合。这样的损失函数可以缩小
和正例样本之间的差距,并推开
和负例样本。
一般来说,传统的对比学习基本就考虑到这里就可以了。但本文的作者对正例集合的构造给出了一种全新的方式。
考虑到正例样本的多样化,即:他们来自同一个类簇但通过预训练模型的编码他们会拥有不同的语义信息。因此,重要的是要确定哪些正例样本应该用于对比损失的计算,否则,学习到的表示可能不会得到紧密的类簇。因此,作者提出了两个学习表示的目标:1)使得同一个类簇中的样本尽可能紧凑;2)将那些不在同一个类簇中的样本尽可能推远。
根据该目标,下图展示了在对比学习中需要重点关注的两类正例样本:。这部分最不相似的正例样本又被称为hard-positives.
基于这个出发点,从原始的正例集合中选取个最相似的正例
和
个最不相似的正例
,并且只针对这些选好的正例样本来进行表示的更新。作者给出的理由是:计算所有的正例样本可能会破坏与分类表示无关的语义信息;并且可能会影响分类结果,因为类簇级别的正例样本表示可能与锚点样本
有很大的不同。根据选定的正例样本,前面的
可以被重写为:
和
的比例也是一个非常关键的参数。
- 动量对比优化
显然,在对比学习训练过程中,使用大量的负例样本可以帮助更好地采样编码表示的底层连续高维空间。因此,动量对比框架MoCo被用来以基于队列更新策略来考虑大规模的负例样本。在动量对比框架中,包含两个独立的编码器:针对查询(锚点)query的编码器,针对key的编码器。query编码器由来自查询样本的梯度下降来更新,而key编码器则由一个动量的过程来进行更新:和
是编码器,而只有查询编码器
通过反向传播通过梯度更新,并以此来驱动
的更新。
首先将负例表示压入循环队列,只有在队列末尾的样本会通过key编码器进行编码来更新。(注:这种更新是在key编码器经过动量更新之后执行。)通过动量更新过程,对比学习过程可以考虑大量的正负例样本,因为该过程不需要计算所有正负例的梯度。 -
双目标训练
最终的训练损失如下:
在训练的过程中,查询样本和其对应的正例和负例的编码都由BERT中[CLS]token的输出为对应的表示。在微调的过程中,作者将原始的交叉熵损失和对比损失结合到一起进行表示学习。从这里可以看出,用于分类的交叉熵损失是对标签信息的直接利用,而在对比学习中,则是利用标签信息进行正负例的构造,使得学习到的表示更有利于类簇的划分。
部分实验
笔者这里主要关注了最相似正例和最不相似正例选取的数量以及其对应的比例:可以看出的一点是,不同数量的hard-positives对性能的影响是非常重要的。这表明,引入适当数量的hard-positives有利于学习更好的表示。
总体来说,对于基于BERT微调的分类任务,作者引入KNN分类器来提供更加鲁棒的分类预测结果;在该目标的驱动下,为KNN的有效预测设计了对应的对比学习过程。在该过程中,提出了基于类别标签的正例选择方式,并且定义了两种值得关注的正例样本:与查询样本最相似的正例和与查询样本最不相似的正例。接着,引入动量对比框架以构造更多的标签级别的正负例样本对。环环相扣,最终得到了显著的性能提高。
其实笔者对基于队列的负例更新策略不太能get到。可能类似这样,将所有的样本都push进循环队列,然后根据样本标签来判断哪些是可用负例?反正,key编码器也不进行参数的更新,一次用多少也不会增加计算量。(: