论文阅读“miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings”

Klein T, Nabi M. miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings[J]. arXiv preprint arXiv:2211.04928, 2022.

摘要导读

本文提出了一个基于互信息的对比学习框架miCSE,显著地提高了在少量句子嵌入方面的先进水平。该方法在对比学习过程中调整了不同视图之间的注意力模式。通过miCSE学习句子嵌入需要加强每个句子的增强视图的结构一致性,使对比自监督学习提高了效率。因此,在小样本领域该方法取得了很好的效果,在全样本的场景下依旧适用。

本文的贡献如下:

  • 通过添加一个attention-level的目标,将结构性信息引入到语言模型中。
  • 引入了注意力互信息(AMI),一种可以提高样本效率的自监督对比学习方法。
方法浅析

该方法旨在在对比学习的方案中利用到句子结构信息。与传统的仅在嵌入空间中的语义相似度水平上进行操作的对比学习相比,该方法在模型中注入了结构信息。这是通过在训练过程中正则化模型的注意力空间来实现的。

符号声明

给定字符串语料库\mathcal{X},其对应的数据集表示为\mathcal{D}=\{x_1, x_2, \cdots,x_{|X|}\}x_i \in \mathbb{N}^n表示含有n个token的序列。在对句子映射时,本文采用的是bi-encoder--f_{\theta},其输入是输入句子的不同类型的增强表示。这里使用v \in \{1,2\}作为增强视图表示的索引值。因此,对batch_size为\mathcal{D}_b进行编码,会得到嵌入矩阵E_v \in \mathbb{R}^{|{\mathbb{D}_b}|\times U}U是嵌入表示的维度。使用Transformer的话,对应于E_v产生的还有其相关联的注意力矩阵W_v。因此,提出的模型联合优化如下损失,以达到对语义和结构的一致性学习:

显然,损失函数中的第一项是语义的对齐,在嵌入表示空间中使用传统的InfoNCE来实现;第二项是在注意力空间中对句法的对齐,不同的是,句法的对齐仅关注正例。

Embedding-level Momentum-Contrastive Learning (InfoNCE)

InfoNCE loss试图在嵌入空间中将正例对拉在一起,同时将负例对分开。具体来说,嵌入的InfoNCE推动每个样本以及相应的增强嵌入表示之间的相似性。对应的损失函数如下:

其中e_i \in E_1+e_i \in E_2分别是对应于x_i的两种不同的嵌入表示。d(x, y)=\text{exp}(sim(x, y)/\tau)sim(\cdot)为余弦相似度。显然,负例的构成包含两种形式:(1)给定batch中除了当前样本之外的样本;(2)存储在\mathcal{Q}中的前序batch中的嵌入表示(这是动量编码器中扩充负样本的常用操作)。

Attention-level Mutual Information (AMI)

首先关于Transformer中的注意力机制的详情这里不再进行赘述。只需要知道在每个注意力头中包含三个矩阵Q,K,V,其中Q,K进行运算会得到注意力权重矩阵W=softmax(f(Q,K)) \in \mathbb{R}^{n \times n},其中f(\cdot)表示缩放点积。最终注意力头的输出为WV。当然在实际的应用中为了得到不同的嵌入子空间,一般会将该注意力操作重复H次,被称为多头注意力机制。在训练编码器的过程中,自注意力张量 W的值会受到随机确定性过程的影响,这种随机性由dropout操作产生。因此,基于结构信息的对齐,则是要最大化W_v=[w_1, w_2, \cdots, w_{|{\mathcal{D}}_b|}]之间的互信息。本文通过四个步骤来正则化注意力空间。

  • Attention Tensor Slicing
    Attention tensor slicing
    在输入上实例化一个Transformer栈会产生一个注意张量W,其中包含Transformer的层数L和自注意力的头数H。显然,这里将w打平成1维的张量是为了在三维空间中更好的展示tensor的shape。因此,对于x_i的某个自注意力头来说,其输出是w_i \in \mathbb{R}^{1 \times n \times n}。考虑到多层和多头,其对应的张量就变成了W \in \mathbb{R}^{L \times H \times n \times n}
    slicing函数的主要作用是将每个输入样本的注意力张量切分为R个元素:
    即:
    其中对于r \in [1, R],每个元素w_i^r \in \mathbb{R}^{n \times n}=(w_{j,k})_{1<=j,k<=n}。如果token个数小于n,则通过填充对不同长度的序列进行补齐,以适应批处理。(这种切分方式的好处是还是保留了句子x_in个token之间的相关关系。只是不知道R的设定是否会对性能产生大的影响。)
  • Attention Sampling
    虽然[PAD]的填充会使得在GPU上可以进行有效的批处理编码,但在查看相关关系时,需要放弃token对[PAD]标记产生的注意力得分。为了适应不同长度的标记化序列,对每个网格单元内的注意得分执行采样操作。采样使用的是多项式分布
    对于有值的s,(1<=s<=n),构造s^2的注意力得分池,每个得分被等概率采样,其余的则概率为0。对于每个w_i^r对应的表示中,采样m个注意力得分构成集合
    注:这里忽略了样本下标i。具体来说,J_r由以下多项式分布产生
    因此,对于同一个切分元素r而言,会使用相同的采样索引J_r
    不同增强视图的采样后表示如下:
  • Attention Mutual Information Estimation
    文章提出使用互信息来衡量不同视图下注意力模式的相似性。具体来说,采用对数正态分布对注意力得分分布进行建模。(torch.Tensor.log_normal_())
    两个正态分布的元组向量(z_1, z_2)的互信息可以写为其相关性函数:
    ρ对应于由z_1z_2计算出的相关系数。
    因此,对于给定样本x_i的第r的切分元素而言,其对应的互信息可以写成
    这里的log(\cdot)函数,用于实现从 Log-Normal到Normal随机变量的转换。对应的实现细节如下:
    实现的伪代码中有三点需要注意:
    (1)在切分注意力tensor的时候,给出的伪代码缺少了对w_i的切分步骤;
    (2)在采样的时候就已经是将n \times nw_i^r打平成1维的进行操作了;
    (3)采样过后的值,应该就是将n^2w_i^r变成了m维的向量,然后进行后续的运算,最终返回互信息值的大小。
  • Mutual Information Aggregation
    为了计算注意力正则化的损失分量,需要聚合整个张量的分布相似性。聚合是通过对批次中每个切片r \in R和每个样本x_i的个体相似性进行平均得到的。给定权重缩放因子\lambda \in \mathbb{R},权重对齐损失写成如下形式:
相关实验设置

(1) InfoNCE中另一种负例的设置:The momentum encoder is associated with a sample queue of size |Q| = 384.
(2)切分中,关于切分个数以及采样个数的设置:From each of the (4 × (H/2)) chunks of pooled attentions, we random sample 150 joint-attention pairs for each embedding of the bi-encoder.

实验设置中关于切分的部分,看上去是将层数L每4个分为一组,而在注意力头数H上则是被分为了两组。


想法真的是很出彩,一般来说,大家都考虑引入额外的网络结构通过语义的嵌入表示来获取结构信息以达到语义和结构的一致性。本文作者利用自注意力计算中的中间步骤作为结构信息从而在不增加额外网络结构的情况下使得PLMs引入了结构信息。

但也会存在一定的疑问,为什么一定要进行切分?直接对每个头每个层的注意力权重进行对比不是更加方便吗?除了减少计算量,是否还有别的说法?如果是想让不同的Transformer和不同的Multi-head之间产生交互的话,采用滑动窗口的方式是不是更好?


最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,133评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,682评论 3 390
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,784评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,508评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,603评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,607评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,604评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,359评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,805评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,121评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,280评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,959评论 5 339
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,588评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,206评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,442评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,193评论 2 367
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,144评论 2 352

推荐阅读更多精彩内容