bert原码解析(embedding)

       写这篇文章的起因是看ALBERT的时候,对其中参数因式分解,减少参数的方式不理解,后来通过原码来了解原理。后来想到虽然平时基于bert的nlp任务做的挺多的,但对原理还是一知半解的,所以在此记录。后续有时间的话,将常见的,看过的论文做个总结,不然容易忘记。(attention is all your need,bert,albert,roberta,sentence -bert,simcse,consert,simbert,nezha,ernie,spanbert,gpt,xlnet,tinybert,distillbert)


图一

从图一可以明显看出,bert主要分为三块。embedding层,encoder层,以及pooler层,本章为embedding层的原码分析。



input embedding


input embedding原码

可以看出,输入的input,会先经过tokernizer,会补上cls,sep等特殊字符。然后embedding层会获取句子的token embeddings+segment embeddings+position embeddings作为最终的句子embedding。

1 token embedding:


token embedding

token embedding有两种初始化方式。如果是训练预训练,随机出初始化一个30522*768的lookup table(根据wordpiece算法,英文一共有30522个sub-word就可以代表所有词汇,每个sub-word 768纬)。如果是在预训练模型的基础上finetune,读取预训练模型训练好的lookup table。假设输入的句子经过tokernized长度为16。经过lookup table就是16*768维的句子表示。

2 position embedding:


position embedding

position embedding的lookup table 大小512*768,说明bert最长处理长度为512的句子。长于512有几种截断获取的方式。position embedding的生成方式有两种:1 根据公式直接生成 2 根据反向传播计算梯度更新。其中,transformer使用公式直接生成,公式为:

position embedding生成方式1

其中,pos指的是这个word在这个句子中的位置;2i指的是embedding词向量的偶数维度,2i+1指的是embedding词向量的奇数维度。为什么这个公式能代表单词在句子中的位置信息呢?因为位置编码基于不同位置添加了正弦波,对于每个维度,波的频率和偏移都有不同。也就是说对于序列中不同位置的单词,对应不同的正余弦波,可以认为他们有相对关系。优点在于减少计算量了,只需要一次初始化不需要后续更新。

其中, bert使用的是根据反向传播计算梯度更新。

3 segment embedding:


token embedding size

bert输入可以为两句话。[cls]....[seq]....[seq]。每句话结尾以seq分割。从embedding的大小可以看出,lookup table由两个768组成,对应第一句和第二句。该参数也由训练得到。

4 LN以及dropout:


embedding形成

embeddings = dropout(layernorm(token embeddings+segment embeddings+position embeddings))。Normalization 有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为 0 方差为1的数据。我们在把数据送入激活函数之前进行normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区,发生梯度消失的问题,使得我们的模型训练变得困难。这里不使用bn可以去除batch size对模型的影响。

下一篇为bert核心encoder模块的解析。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容