Memory Networks 最早是在 2014 年由 FaceBook 提出,和 Attention 机制比较类似,主要用于 QA 系统。常见的 RNN 模型虽然有一定的长期记忆能力,但是记忆能力不足。Memory Networks 会采用可读写的记忆组件 (Memory Slots) 将上下文信息保存下来,然后和后续推断模块一起训练,增强模型的长期记忆能力。本文主要介绍论文 《Memory Networks》的成果 MEMNN。
1. Memory Networks
《Memory Networks》是 FaceBook 2014 年的论文,首次提出了 Memory Networks (MEMNN),可以解决传统模型不能有效利用长期记忆进行推理的问题。Memory Networks 的模型如下图所示。
MEMNN 包含 4 个部分,Input (I),Generalization (G),Output (O),Response (R),每个部分的作用如下:
I : 将输入的句子 x 转成内部的特征表示,包括一些预处理步骤,可以将句子转成稀疏或者稠密的特征向量 I(x)。
G: 把模块 I 得到的内容保存在 memory slots 里,类似于一个数组,一般情况下 MEMNN 只会将新的内容 I(x) 增加到 Memory Slots 中,不更改旧的 memory。
O: Output 模块接收"问题",并从 Memory Slots中找出最相关的 k 个 memory。
R: 根据问题和模块 O 找到的 memory 进行推断,得到答案。
2. Basic Model
现在介绍 Memory Network 中的 Basic Model (基础模型),模块 I 接收输入的句子,包括一些上下文信息,问题的答案可以从这些上下文信息获取。 假设有 N 个输入句子,模块 G 会将句子依次存放到 Memory Slots,分别为 m1, m2, ..., mN。
推断的重点在于模块 O 和 R,在得到问题 x 后,模块 O 要从 Memory Slots 中找出 k 个最相关的 memory,Basic Model 中 k 设为 2,则首先要找出相关性最高的 memory:
接着我们同时利用问题 x 和 mo1 找出相关性第 2 高的 memory:
因此我们得到模块 O 的输出,包括问题 x 和两个 memory:
模块 R 可以从字典中挑选最合适的词 w 作为答案:
打分函数 SO 和 SR 的形式都是如下:
3. Basic Model 训练
Basic Model 的训练采用 margin ranking loss 和随机梯度下降 SGD,loss 函数公式如下:
gamma 代表 margin,loss 包含了三个部分:
第一项: 使 x 和 mo1 的得分减去任意别的 memory 差值大于 gamma
第二项: 使 x、mo1 和 mo2 的得分减去任意别的 memory 差值大于 gamma
第三项: 使单词 r 的得分减去任意其他单词的得分大于 gamma
4. 实例
如图中的例子,上下文句子包括前两行的所有句子,这些句子都保存在 Memory Slots 中。
在回答第一个问题 "Where is the milk now?",MEMNN 首先要在 Memory 中找到最相关的句子 m1
m1 = "Joe travelled to the office"
然后根据问题 和 m1 找到下一个相关句子 m2
m2 = "Joe left the milk"
最后通过模块 R 得到答案 "office"。
5. 参考文献
《Memory Networks》