Klein T, Nabi M. miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings[J]. arXiv preprint arXiv:2211.04928, 2022.
摘要导读
本文提出了一个基于互信息的对比学习框架miCSE,显著地提高了在少量句子嵌入方面的先进水平。该方法在对比学习过程中调整了不同视图之间的注意力模式。通过miCSE学习句子嵌入需要加强每个句子的增强视图的结构一致性,使对比自监督学习提高了效率。因此,在小样本领域该方法取得了很好的效果,在全样本的场景下依旧适用。
本文的贡献如下:
- 通过添加一个attention-level的目标,将结构性信息引入到语言模型中。
- 引入了注意力互信息(AMI),一种可以提高样本效率的自监督对比学习方法。
方法浅析
该方法旨在在对比学习的方案中利用到句子结构信息。与传统的仅在嵌入空间中的语义相似度水平上进行操作的对比学习相比,该方法在模型中注入了结构信息。这是通过在训练过程中正则化模型的注意力空间来实现的。
符号声明
给定字符串语料库,其对应的数据集表示为,表示含有个token的序列。在对句子映射时,本文采用的是bi-encoder--,其输入是输入句子的不同类型的增强表示。这里使用作为增强视图表示的索引值。因此,对batch_size为进行编码,会得到嵌入矩阵,是嵌入表示的维度。使用Transformer的话,对应于产生的还有其相关联的注意力矩阵。因此,提出的模型联合优化如下损失,以达到对语义和结构的一致性学习:
Embedding-level Momentum-Contrastive Learning (InfoNCE)
InfoNCE loss试图在嵌入空间中将正例对拉在一起,同时将负例对分开。具体来说,嵌入的InfoNCE推动每个样本以及相应的增强嵌入表示之间的相似性。对应的损失函数如下:
Attention-level Mutual Information (AMI)
首先关于Transformer中的注意力机制的详情这里不再进行赘述。只需要知道在每个注意力头中包含三个矩阵,,,其中,进行运算会得到注意力权重矩阵,其中表示缩放点积。最终注意力头的输出为。当然在实际的应用中为了得到不同的嵌入子空间,一般会将该注意力操作重复次,被称为多头注意力机制。在训练编码器的过程中,自注意力张量 的值会受到随机确定性过程的影响,这种随机性由dropout操作产生。因此,基于结构信息的对齐,则是要最大化之间的互信息。本文通过四个步骤来正则化注意力空间。
- Attention Tensor Slicing
slicing函数的主要作用是将每个输入样本的注意力张量切分为个元素: - Attention Sampling
虽然[PAD]的填充会使得在GPU上可以进行有效的批处理编码,但在查看相关关系时,需要放弃token对[PAD]标记产生的注意力得分。为了适应不同长度的标记化序列,对每个网格单元内的注意得分执行采样操作。采样使用的是多项式分布
不同增强视图的采样后表示如下: - Attention Mutual Information Estimation
文章提出使用互信息来衡量不同视图下注意力模式的相似性。具体来说,采用对数正态分布对注意力得分分布进行建模。(torch.Tensor.log_normal_())
两个正态分布的元组向量的互信息可以写为其相关性函数:
因此,对于给定样本的第的切分元素而言,其对应的互信息可以写成
(1)在切分注意力tensor的时候,给出的伪代码缺少了对的切分步骤;
(2)在采样的时候就已经是将的打平成1维的进行操作了;
(3)采样过后的值,应该就是将的变成了维的向量,然后进行后续的运算,最终返回互信息值的大小。 - Mutual Information Aggregation
为了计算注意力正则化的损失分量,需要聚合整个张量的分布相似性。聚合是通过对批次中每个切片和每个样本的个体相似性进行平均得到的。给定权重缩放因子,权重对齐损失写成如下形式:
相关实验设置
(1) InfoNCE中另一种负例的设置:The momentum encoder is associated with a sample queue of size |Q| = 384.
(2)切分中,关于切分个数以及采样个数的设置:From each of the (4 × (H/2)) chunks of pooled attentions, we random sample 150 joint-attention pairs for each embedding of the bi-encoder.
实验设置中关于切分的部分,看上去是将层数每4个分为一组,而在注意力头数上则是被分为了两组。
想法真的是很出彩,一般来说,大家都考虑引入额外的网络结构通过语义的嵌入表示来获取结构信息以达到语义和结构的一致性。本文作者利用自注意力计算中的中间步骤作为结构信息从而在不增加额外网络结构的情况下使得PLMs引入了结构信息。
但也会存在一定的疑问,为什么一定要进行切分?直接对每个头每个层的注意力权重进行对比不是更加方便吗?除了减少计算量,是否还有别的说法?如果是想让不同的Transformer和不同的Multi-head之间产生交互的话,采用滑动窗口的方式是不是更好?