Breaking the Softmax Bottleneck:A High-rank RNN Language Model

Content

此篇论文主要完成了:
1.通过数学推导,找到并证明了限制RNN-based LMs的性能瓶颈之一——Softmax Bottleneck问题
2.针对这个瓶颈,提出了一个解决方案—— Mixture of Softmaxes

Introduction

Language Modeling Problem

LM问题在指已知了一个符号(token)序列:
X=(X_1,...,X_T)
的情况下,生成一个Language模型来模拟这个序列出现的概率,即求解P(X)的值:
P(X)=\prod_tP(X_t|X_{<T})
而根据链式法则(chain rule)和马尔科夫假设(Markov Assumption), P(X)的值可以通过求解它对应的联合概率得出:
P(X)=\prod_tP(X_t|X_{<T})=\prod{P(X_t|C_t)}
X_t: 下一个符号的概率分布
C_t:历史序列 / 已经出现的所有Tokens

因此,原始的LM问题就转换成了:在每个时刻t,根据已知的符号序列C_t(History), 求解下一时刻可能输出的符号的概率。 由于输出符号有很多种而可能性,所以这个概率的其实是一个在Vocabulary(或Token Set)上的概率分布。

Standard Approach: RNN based LMs

由于符号序列自带时间属性,而我们需要模拟的也是符号之间的时间依赖关系。因此对于LM问题来说,最标准的,而且state of art 模型都是基于RNN的。以下为一个基于RNN的Language Model的结构简图:

RNN based LMs

其中:
h_t=\sigma(Vh_{t-1}+Ux_t)

o=Wh_t

P(x_t|c_t)=y_t=\frac{e^{o(x_t)}}{\sum_ve^{o(v)}}

首先图片左下角的符号序列 " the cat sat on the " 是我们在此时刻 t 已知的历史序列,即
C_t
。由于输入的每个Token是由one-hot编码表示的,当Vocabulary很大的情况下,这个输入维度会非常的高。因此在处理这种高维输入时,会先使用word embedding matrix ( 图中矩阵U ) 来降低维度&学习词语的内部联系,使输入更有意义。
之后,经过处理的输入会被传送给RNN。基于此时刻 t 的输入和上一时刻RNN的旧隐藏状态
h_{t-1}
, RNN会产生新的隐藏状态
h_t

此隐藏状态可能看作是一个由RNN学习到的,含有下一时刻输出符号的信息的特征。由于输出是一个基于Vocabulary的概率分布,因此我们必须把学习到的这个特征映射回初始的Vocabulary。上图中,
h_t
下的output embedding matrix (W) 就负责这个反映射。应用上,两个embedding 矩阵 U和W是一样的。

Hypothesis&Main Issues

由Anton Maximilian Schäfer and Hans Georg Zimmermann写的Recurrent Neural Networks Are Universal Approximators论文可以得知,RNN的表达力是很强的,它可以模拟逼近任意的非线性动态系统(Universal approximation theorem)。由此作者推测出,基于RNN的LMs的性能瓶颈之一应该是RNN最后使用点乘+softmax操作,即:o=Wh_t ; out = softmax(o)

Mathematical Analysis of LM

Defination

为了能进行数学推导和定量分析来证明这个假设,首先我们需要一个自然语言的数学表达。自然语言L可以表示成N个元组的集合:
L=\{(c_1,P^*(X|c_1)),...,(c_N,P^*(X|c_N))\}
其中:
c_i:代表了语言中的任一个可能的context(history token序列)
P^*(X|c_i):真实的数据分布,即:已知一个历史符号序列(c_i),下一符号在Token集合X上的概率分布
X=\{x_1,x_2,...,x_M\}: 代表了语言L中所有可能出现的符号
N: 所有可能的上下文(符号组合)的数目
至此,LM问题可以转换成如下的数学公式表达:
P_\theta(X|c)=P^*(X|c)
即,给定一个自然语言L,LM需要学习一组参数\theta,基于此组参数的模型可以逼近真实的任一上下文(context)所对应的下一符号概率分布。
若我们使用RNN-based LMs, 那么在network的输出端,我们能从softmax layer 的输出直接得到基于此时刻 t 的下一符号概率分布P_{\theta}(X|c) :
P_\theta(X|c)=\frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}
因此,训练模型的Objective可以用以下等式表达:
P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c)
即,我们使用一个RNN-based LM 来模拟每个可能context下的下一符号概率分布,并且不断优化模型使用的参数\theta,使LM输出的概率分布逼近真实分布。

Matrix Factorization Problem

在数学化表达LM问题后,它的Objective公式还可以通过矩阵分解来做进一步的分析。
P_{\theta}(X|c)的表达式中,h_c^T代表了输入是不同的context(历史序列)的情况下,RNN所对应的不同隐藏状态。此处,可以把所有可能的情况列出,排列组合成一个矩阵:
H_{\theta}=\left[ \begin{matrix} h^T_{c_1} \\ h^T_{c_2} \\ ... \\ h^T_{c_N} \end{matrix} \right]
这个矩阵包含了RNN针对不同的Context的所有可能的隐藏状态。根据此节开篇的假设,自然语言L一共有N种可能的context(即:符号组合序列)。
相似地,公式中的w_x也可以统一成矩阵表达:
W_{\theta}= \left[ \begin{matrix} w^T_{x_1} \\ w^T_{x_2} \\ ... \\ w^T_{x_M} \end{matrix} \right]
W_{\theta}中的每一行代表了语言L中的某一个符号x_i所对应的embedding coefficient,用以把RNN学到的隐藏状态映射回包含X符号集的Vocabulary空间。同样,根据此节开篇的假设,自然语言L一共有M种可能的符号(tokens)。
最后,我们还需要把自然语言L真实的条件概率分布(在各种可能的context下,下一符号的概率分布)用矩阵的方式表达,从而能使用矩阵知识,数学地分析RNN-based LMs。此处假设矩阵A代表了真实条件概率分布P^*(X|c)\log后的结果:
A= \left[ \begin{matrix} \log{P^*(x_1|c_1)} &\log{P^*(x_2|c_1)}&...&\log{P^*(x_M|c_1)}\\ \log{P^*(x_1|c_2)} &\log{P^*(x_2|c_2)}&...&\log{P^*(x_M|c_2)} \\ ...&...&...&... \\ \log{P^*(x_1|c_N)} &\log{P^*(x_2|c_N)}&...&\log{P^*(x_M|c_N)} \end{matrix} \right]
由上公式可知,A包含了context与对应next token的所有可能的组合。

Rank Analysis

在经历了上述对Objective的分析及矩阵转换,RNN-based LM问题事实上可以抽象如下:
\exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A
即,通过学习,我们希望找到一组参数\theta,以它为参数的LM模型(即RNN)可以逼近真实的下一符号概率分布的\log
为了能推导出Softmax存在的瓶颈,首先先要引入一个矩阵操作row-wise\ shift。对一个矩阵 A 进行row-wise\ shift操作,其结果为一个矩阵集合F(A):
F(A)=\{ A+\Lambda J_{N,M}| \Lambda \ is\ diagonal\ and \ R^{{N}\times{N}}\}
其中:
J_{N,M}:维度对应的全1矩阵
\Lambda:对角线元素值任意的对角线矩阵
事实上,row-wise\ shift的作用是把矩阵A中的每行元素上加上任意一个实数,例如如下\Lambda J_{N,M}与矩阵A相加后,A的第 i 行会被加上一个实数a_i
\left[ \begin{matrix} a_1&0&0 \\ 0&a_2&0 \\ 0&0&a_3 \end{matrix} \right]_{\Lambda^{3\times3} } \times{\left[ \begin{matrix} 1&1&1 \\ 1&1&1 \\ 1&1&1 \end{matrix} \right]_{J^{3\times4}}}={\left[ \begin{matrix} a_1&a_1&a_1 \\ a_2&a_2&a_2 \\ a_3&a_3&a_3 \end{matrix} \right]}
而代表真实下一符号概率分布的 \log 矩阵 AA 经由 row-wise\ shift 所得到的矩阵集合 F(A) ,有如下两个特殊性质:
1.所有真实数据分布所对应的logits都包含在了集合F(A)中。

  1. F(A) 中的所有矩阵的秩
    都相似,相差不大于1。
    附--矩阵的秩 :
    -定义: 矩阵中所有线性独立的列的数目和
    -直观解释:如果一个矩阵有着更高的秩,那么说明它有更多的线性独立的列。若把这些列看作是一组 basis vectors ,那么它们所能表达的空间就更复杂,表达能力就更强。即,高秩的矩阵能包含更多的信息量。
    -例子:如果我们把某自然语言L表示成矩阵形式(如上节中的矩阵A),那么此矩阵A天然拥有高秩的性质,例如:
    -它是高度依赖上下文的——“南”后面的符号可以是“京”或者“瓜”,取决于前后文是关于地理的还是农业的。即,在不同的上下文里,下一符号的概率分布会非常不同。
    -并且我们不可能找到一组有限数目的basis vectors,使用此基来表达语言L中的所有Token的关系。

review

由RNN-based LM的结构推导出,它的Objective如下:
P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c)
通过把自然语言表达成矩阵形式,再进行矩阵分解(Matrix Factorization ),LM的目标可以抽象成如下表达。即,LM需要找到一组参数,借由这组参数生成的下一符号概率能无限逼近真实概率:
\exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A
而通过引入矩阵运算符 row-wise shift ,以及此运算产生的矩阵集F(A)的第一个性质,我们可以推出,若RNN-based LM真的能逼近真实概率分布,那么它产生的 logits 必定属于真实概率分布矩阵 Arow-wise shift 结果集合中。即,Objective为如下:
\exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)}

Problem: Softmax Bottleneck

至此,LM问题的核心变成了研究是否真的存在一组参数\theta,使基于此\theta的LM所产生的logits属于 F(A) ,如下:
\exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)}
回忆一下,如上公式中:
H_{\theta}\in{R^{N\times{d}}},代表了所有可能的context输入下的对应隐藏状态。
W^T_{\theta}\in{R^{M\times{d}}},代表了语言中所有可能的token所对应embedding coefficient
因此,由线性代数的知识可知,它们乘积的秩应该小于d,即:
rank(H_{\theta}W^T_{\theta})\leq{d}
(相较于自然语言中的context数目N和token数目M,embedding size d显然会小很多)
又由于row-wise shift的第二个性质(即:F(A)中的所有矩阵的秩都相似,相差不大与1)可推导出,若embedding size d有:
d<min_{A^{'}\in{F(A)}}rank(A^{'})
则对应的RNN-based LM 产生的logits不可能属于F(A)。换句话说,此LM不可能找到一组参数\theta,使其能recover真实概率分布A
到底embedding size d能否满足上述不等式呢?我们已知,真实概率分布矩阵A也属于F(A),而且它是高秩的矩阵,其秩最大能和它的context数目相当(10^{5})。而embedding本就是为了精简输入维度而使用的,所以它的维度一般会较小(10^2)。所以显然成立:
d<min_{A^{'}\in{F(A)}}rank(A^{'})
即,RNN-based LM 不可能找到一组参数 \Theta ,使其能recover真实概率分布 A。它只是一个真实概率分布的低秩近似,表达能力不够,因此失去了一些模拟context间依赖性的能力。这也正是性能瓶颈所在。

Sloution for Softmax Bottleneck

Naive Solution

要解决这个瓶颈问题,一个最直观的方法就是提高embedding size d。但是这显然与embedding的目的不符。另一个方法是使用Ngram模型,来避免Softmax的使用。这两种方法都会使总参数数目急剧增加,容易导致过拟合,显然都不可取。

Mixture of Softmaxes

而另一种方法就是使用作者提出的 MoS(Mixture of Softmaxes) 来替代原始的 Softmax 。MoS的公式如下:
P_{\theta}(X|c) = \sum^K_{k=1}\pi_{c,k}\frac{exp(h^T_{c,k}w_x)}{\sum_xexp(h^T_{c,k}w_x)} \ \ \ \ \ \ \ s.t. \ \sum^K_{k=1}\pi_{c,k}=1
由名字可知,Mos便是把多个Softmax按权相加,综合为一个Softmax混合模型。
传统的RNN-based LM的结构如下左图,而基于MoSRMM-LM 位于下图右。由比较可看出,仅在RNNhidden state h_t 以后有所不同。

standard RNN vs. MoS

这两种不同的模型最后产生的下一符号概率分布的
\log
也不同,如下:
\widehat{A}_{MoS}=\log\sum^K_{k=1}\Pi_k\exp(H_{\theta,k}W^T_\theta)

\widehat{A}_{Softmax}=\log\exp(H_{\theta}W^T_\theta)

\widehat{A}_{MoS}
这个优化版本由于引入了按权相加,因此在最后计算完
\log
运算后,与模型产生的logits不再是原本的线性关系,理论上可以达到任意的高秩,因此提升了模型的表达能力。

Experiments

使用MoS的RNN与其他模型在LM问题上的表现对比如下:


result

Drawback

当然,MoS模型也有它的缺憾。由于使用了多个并行的Softmax按权相加,因此它的运算时间是原有模型的数倍。在实践中,其实Softmax Layer的计算是尤其费时的,因此这也算是不小的短板。由下图实验数据可知,MoS模型的计算时间与它所用的Softmax的数目K近似呈线性关系。

drawback

computational time / #softmax

Summary

现在普遍使用的RNN-based LM,由于在最后把RNN输出的隐藏状态h_t乘以了output embedding matrix,并把得到的结果(logits)输入了softmax layer,导致最后整体模型所能模拟的概率分布空间的秩被embedding-size d 所限制。而MoS模型通过引入按权相加的运算打破了原来的线性关系,提高了模型模拟空间的秩。当然,其代价是线性增加的运算时间。

REFERENCES

[1]Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, William W. Cohen. Breaking the Softmax Bottleneck: A High-Rank RNN Language Model. In ICLR 2018.
[2]Anton Maximilian Schäfer and Hans Georg Zimmermann. Recurrent neural networks are universal approximators. In International Conference on Artificial Neural Networks, pp. 632–640. Springer, 2006.
[3]Tomas Mikolov, Martin Karafiát, Lukas Burget, Jan Cernocky, and Sanjeev Khudanpur. Recurrent neural network based language model. In Interspeech, volume 2, pp. 3, 2010.
[4]Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and optimizing lstm language models. arXiv preprint arXiv:1708.02182, 2017.

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,294评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,493评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,790评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,595评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,718评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,906评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,053评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,797评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,250评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,570评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,711评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,388评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,018评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,796评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,023评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,461评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,595评论 2 350

推荐阅读更多精彩内容