论文信息
项目 | 内容 |
---|---|
作者 | Karol Kurach & Marcin Andrychowicz & Ilya Sutskever |
发表 | ICLR 2016 |
摘要和前言
本文实现了一个可以操作和读取指针的神经网络架构,称为 Neural Random Access Machine 。其特点是可以操作一个可变大小的外部记忆。通过学习需要操作指针才能完成的任务验证其能力,并且发现模型可以解决此类问题并使用链表、二叉树等结构。对于简单的任务,模型可以泛化到任意长度的序列上。在特定的假设下,记忆可以在常数时间内读取。
作者认为,神经网络的进步来源于:结构更深的同时,参数更少,且可训练。 Neural Turing Machine 和 Grid-LSTM 的成功在于深度、短期记忆的大小和参数数量,三者相互独立。
模型
模型描述
模型有 个寄存器,每个寄存器储存一个整数,用 上的分布来表示。控制器不能直接访问寄存器,但可以通过一系列预定义的“模块”(module,或称“门”,gate)来与之交互,举例来说,整数加法,等值测试等等。
因此模块记作 ,且
也就是集合上的一个二元运算。
模型每一时间步上进行:
- 控制器根据寄存器的值取得一些输入
- 控制器更新内部状态(是一个LSTM)
- 控制器输出一个“模糊电路”(fuzzy circuit)的描述。包含输入 ,门 和 个输出
- 寄存器的值被模糊电路的输出覆写
其中电路构成如下:
模块 的输入是控制器从 中选出的。其中:
- 表示当前时间步第 个寄存器储存的值
- 表示当前时间步第 个模块的输出
控制器对输入进行加权平均,决定哪些值作为输入。因此,对于 ,
其中 , 是控制器生成的权重向量。
为使模块接收概率分布输入,并输出一个分布,修改定义如下:
计算完成后,控制器决定哪些结果应该重新存储到寄存器中:
其中 是控制结果储存的权重向量。
每一时间步的开始,控制器接收一些由寄存器决定的输入。朴素的想法可能是将寄存器的值直接作为输入。这样的问题是,如果将整个分布作为输入,模型的参数数量将与 (即寄存器的取值上限)有关。下一节将把 联系到一个外部 RAM 上,因此会妨碍模型泛化到不同的存储大小上。
因此对于每个寄存器,我们只输出一个标量, 。这种设计也有一个优势,即限制控制器得到的输入信息量,强制它使用模块解决问题,而非自己解决。特别地,如果 ,该标量保留了全部的信息。如果 是一个布尔模块的输出,那么它就属于这种情况。例如,不等值测试模块 。
记忆磁带
如果将寄存器初始化为一个输入的序列,在一定时间步后,模型将输出序列产生到寄存器里,那么可以描述一个 seq-to-seq 模型。这种使用方式的缺点在于,无法泛化到长序列上,因为可处理的序列的长度等于寄存器数,而它是一个常数。
因此,设计一个长度为 的记忆磁带,每个位置上是一个记忆单元。每个记忆单元储存一个 的分布。这一内容又可解释为一个磁带上的模糊指针。记忆的准确状态可以用矩阵 来描述。 表示第 个记忆单元存储值 的概率。
模块仅使用两种模块和记忆磁带交互:
-
READ
,接收一个参数作为输入(忽略第二个输入参数),输出记忆磁带该地址上的值。通过与上面类似的方法扩展定义到分布上。具体来讲,对于输入的模糊指针 ,模块输出 -
WRITE
,接收输入指针 和值 ,将指针 处的值替换为 。数学表示是 。其中 是 个 组成的列向量, 表示按元素相乘。
记忆磁带同时也是一个输入/输出通道。记忆初始化成一个输入序列,希望模型将输出写到记忆中。
此外,每个时间步,控制器输出一个结束的概率 。运行在时间步 前没有结束的概率是 ,恰好在时间步 输出结果的概率是 。还有一个超参数,最长时间步数 。如果该步没有结束,模型需要强制输出,即 。
设 表示第 个时间步的记忆矩阵。对于输入输出对 ,其中 ,当记忆被初始化为 时,定义损失函数为 。或者使用对数似然函数定义损失函数,即 。
此外,对于我们考虑的问题而言,输出序列通常比记忆短。我们可以在记忆单元上计算损失函数,因为输出已经被包含在内了。
离散化
在分布上进行计算复杂度很高,比如计算 READ
的时间复杂度是 。人们可能会想(我们在后面用实验证明了)中间值的分布具有很低的熵。在训练过后,我们使用一个离散化的模型进行推理。也就是只选取最有可能的输入,以及输出。具体来讲,就是把上面的 换成在最大值上输出 ,其他位置输出 的向量的函数。
离散化的模型每个寄存器和记忆单元中都储存一个 的整数。因此可以加速。
如果只替换 softmax 的话,寄存器和记忆单元仍可以是分布。根据上下文,此处离散化还包括将所有分布经过一个相同的离散化函数。
对于一个前馈控制器,以及较少数量的寄存器(比如小于20),推理可以进一步加速。因为控制器的输入仅为一些二进制的值,我们可以提前把每种配置都计算出来。
同上,控制器的输入仍可能是 0 到 1 的概率。
实验
训练中使用的技术有 Curriculum Learning [1] 、梯度截断、梯度随机噪声、更新权重后调整分布以使其仍然表示整数的概率分布、对输出的熵过低进行逐步递减的惩罚、限制 计算以防止溢出。
这里介绍一下 Curriculum Learning 。
Continuation Method
为了求解非凸优化问题,我们可以使用 Continuation Method (CM)。基本思想是先计算一个平滑版本的问题,再逐渐降低平滑性。这里利用的直觉是,平滑版本的问题展现了全局特点。这种方法中,需要定义一系列的单参数的损失函数, 。 是一个容易优化的高度平滑的版本, 是我们希望优化的版本。
从抽象的层次来看, CM 也是一系列训练标准。序列中的每一个训练标准都为样本设定了不同的权重,或者更一般地,重新为训练分布设置权重。最初,权重倾向于“简单的”样本,或者那些展示了简单概念的样本。序列中的下一个标准,将越来越提高较难样本的采样概率。序列的末尾,我们在训练样本上均匀采样,因此训练数据的分布就是原始的训练分布。
形式化表示如下:
是表示示例的随机变量(有监督学习中可能是 对), 是学习者最终应该学习到的训练样本分布。 是在 步分给 样本的权重,且 。对应的训练分布即
且使得 ,因此
考虑从 到 的单调递增序列。
定义:如果 的熵递增,则称其为一个 Curriculum 。即
并且
考虑 是有限集上的样例,这一过程对应于增加新的样本。某些实验中,仅仅将训练集划分为简单和完整两步就可以得到提升。另一个极端是随机采样。此时困难样本的概率逐渐增加,直到最后所有样本概率相等,均为 1 。
具体到本篇论文中,以序列的长度或者树的大小作为训练复杂度。
每次训练时,样本从一个由难度 决定的分布中采样得到。每当错误率降低到一定阈值以下,就提高难度,直到最大值。
具体的采样方法是:
首先从一个由 决定的分布中采样得到 :
- 10%: 从所有可能难度中均匀采样
- 25%: 从 中均匀采样,其中 服从每次实验成功概率为 的几何分布
- 65%:
再使用难度为 的样本作为训练样本的训练复杂度。
任务
选取的任务如下:
- Access: Given a value and an array , return .
- Increment: Given an array, increment all its elements by 1.
- Copy: Given an array and a pointer to the destination, copy all elements from the array to the given location.
- Reverse: Given an array and a pointer to the destination, copy all elements from the array in reversed order.
- Swap: Given two pointers , and an array , swap elements and .
- Permutation: Given two arrays of elements: (contains a permutation of numbers and (contains random elements), permutate according to .
- ListK: Given a pointer to the head of a linked list and a number , find the value of the -th element on the list.
- ListSearch: Given a pointer to the head of a linked list and a value to find return a pointer to the first node on the list with the value .
- Merge: Given pointers to 2 sorted arrays and , merge them.
- WalkBST: Given a pointer to the root of a Binary Search Tree, and a path to be traversed (sequence of left/right steps), return the element at the end of the path.
模块
所有的模块都需要事先指定类型和顺序,本次实验中使用的如下:
READ
-
ZERO
-
ONE
-
TWO
-
INC
-
ADD
-
SUB
-
DEC
-
LESS-THAN
-
LESS-OR-EQUAL-THAN
-
EQUALITY-TEST
-
MIN
(a, b)$ -
MAX
(a, b)$ WRITE
实验结果
简单任务
前五个任务被划分为简单任务,因为在训练和测试中均达到了 0 错误率。而且训练结果泛化到序列长度为 50 也是 0 错误率。更进一步地,Copy 和 Increment 被验证可以泛化到任意长度。而对模型进行离散化也不会影响其表现。
让我们分析一下 Copy 的记忆、寄存器以及产生的电路图。
其中电路图是第二步之后的每一步。可以看到此时 r2
储存了转移的长度,每次更新到 r2
自己中,因此保持不变。r3
是累加器,每次进行加一后与 r2
中的较小值存到 r4
中。r4
代表当前读的地址,与 r2
相加后得到写的地址,因此二者通过一次读写完成复制。
因此,每两步(一步 r3
自增存储到 r4
直到与 r2
相等,另一步 r4
实际进行读写)完成一个元素的复制。
可以看到上面的电路持续产生地址常数 0 ,作为写的目的地址。
可以看到 r5
作为读写地址,每次由 r1
递增 1 更新到自己,并实现更新。
可以看到 r3
作为读地址递增,每次用目的地的 2 倍减 r3
减 1 作为写地址(注:实际上只对特定目的地址情况成立)。
困难任务
为了解决困难任务,引入了上面说的很多技术。最终在训练数据中把除了 WalkBST 和 Merge 的错误率调到了 0 。而另两个则调到了 1% 以下。
泛化较好的任务是 Permutation , ListK 和 WalkBST 。离散化则只有 Permutation 没有损失性能。其余的错误率高达 70% 以上。
与已有模型比较
NTM 缺乏将一个指针储存在记忆中的自然的方式。因此作者估计其能完成 Copy 和 Reverse 这样的任务,而难以完成 ListK、ListSearch 和 WalkBST 这样的涉及到指针的任务。
NRAM 的一个特点是缺乏基于内容的寻址,这是有意为之的,目的是加速内存访问。
结论
NRAM 可以解决一些算法类问题。部分解决方法可以泛化到任意序列长度。
参考链接
-
Bengio
Yoshua, et al. "Curriculum learning." Proceedings of the 26th annual international conference on machine learning. ACM, 2009. ↩