Interpretable Recurrent Neural Networks Using Sequential Sparse Recovery
摘要
循环神经网络能高效的处理序列数据。但是,RNNs通常被视为一个黑箱子模型,而其内部的结构和参数的学习都是不能解释的。在这篇文章中,我们提出了对于解决序列稀疏恢复问题的基于序列迭代软阈值算法(SISTA)的可解释的RNN,其使用稀疏隐向量对一系列相关观测值进行建模。SISTA-RNN结果的结构是由SISTA的计算结构所定义的,该结果是一个新型的RNN网络结构。而且,作为标准的统计模型,SISTA-RNN得到的权值更容易的解释,其中包括稀疏字典、迭代步骤的大小,和正则化参数。另外,在具体的序列压缩感知任务中,SISTA-RNN使得训练的时间更快,得到的结果也比包括长短时记忆(LSTM)RNNs在内的的传统黑箱RNNs效果要好。
1 前期相关工作的介绍
对特征学习的解释和机器学习模型的输出都是不确定的。主要的困难是深度学习方法的意义,由于高的计算复杂度,深度学习方法能够学习出有效的函数特征图。与其试图直接解释学习特征或着黑箱深度网络训练的结果,倒不如设计基于概率模型推理的深度网络结构。因为神经网络都是通过概率模型的推理来描述的,网络的学习权值和输出都保留了它们基于模型的意义。
对于基于模型解释的构建,一些前人的工作经常出现类似的稀疏模型方法。Gregor和LeCun[1]提出了迭代软阈值算法(LISTA)的稀疏编码,该方法通过学习编码和解码来提高原始的ISTA算法的速度和性能。Rolfe和LeCun[2]根据ISTA算法在稀疏系数非负约束的条件下构造了网络结构。在这些例子中,网络的非线性是通过一个线性单元(ReLUs)[3]来调整的,网络的权值是可解释稀疏编码参数的函数。Kamilov和Mansour[4]从数据集中提高了ISTA的非线性。我们通过对稀疏恢复的序列研究扩展了前人的研究工作。
回顾人类解释RNNs的过去的工作,Karpathy et al.[5]表明了LSTM显示了一些有意义的文本注释。Krakovna和Doshi-Velez[6]通过将隐马尔科夫和LSTMs组合来增加ARNNs的可解释性。与前人的这些工作不同,我们的目标不是人类的可解释性,而是模型的可解释性,这就意味着我们所提出的SISTA_RNN模型与不是基于明确的概率模型的LSTMs所使用的黑箱模型是不一样的。我们希望基于模型的网络是建立人类可解释模型的更好的一个出发点。
同样,对于现存的黑箱RNN,我们的SISTA-RNN模型能提供一个基于模型的解释。SISTA-RNN的一个单一的循环层,等价与另外一个最近提出的一个架构,即单一RNN(uRNN)[7][8],SISTA-RNN模型既不使用单一层的限制也不适用隐含的复杂值。uRNN模型已经显示了在不同的任务场景中要比LSTMs表现的好。本文的组织结构如下。首先,详细阐述了我们的方法是如何基于模型的可解释性深度网络进行设计的。然后,我们回顾了传统的RNNs结构同时提出了我们的SISTA-RNN模型。最后,我们给出了实验和数据并总结了实验结果。
2 可解释性深度网络的建立
传统的黑箱深度网络通过g函数给出输出结果
给定参数theta和输入X。参数theta是从训练集I输入输出对中最小化损失函数f得到的,该优化问题(1)是使用随机梯度下降法求解的。通常求解的theta对于人类来说不能直接的解释也不能作为统计模型的参数。
在本文中我们使用深度展开[9]的思想通过给出公式(2),来解决这个优化问题(1)。正如在公式(1)中,像之前一样 f 作为训练的损失函数,但是现在的 h 是一个通过theta得到的确定推断函数。这个推论函数试图通过解决另外一个优化问题 P ,该函数用通过参数theta与概率模型相联系。注意到,参数theta即包含一部分概率模型的模型参数又包括用来通过推断函数到优化函数P的超参数。因为 h 试图解决与标准概率模型相联系的一个优化问题,它的参数是很容易解释清楚的。例如,我们将theta看成是稀疏字典和正则化参数问题。
3 传统的黑箱RNN模型
在这里我们简单的回顾下传统的RNNs。RNNs经常是由多个层堆积起来创建的功能强大的网络[10]。
4 可解释的SISTA-RNN模型
首先我们先介绍本文所使用的具体概率模型。然后我们展现一个迭代的方法来推论真正的降噪信号,即序列迭代软阈值算法(SISTA),对应与一个特殊类型的RNN架构,该架构规定不同节点之间连接的传统RNN网络。
SISTA-RNN使用下面的概率模型:
这就是说,观察序列的每一个元素
5 实验和结果
我们使用与Asif和Romberg[12]相似的实验计划,它被设计用来测试序列的压缩感知算法。在这些计划中,维度N=128的信号向量y序列是128X28灰度图像的列。因此,时间维度实际上是列索引,而且所有序列的长度都是T=128。所有的图像都是来自Caltech-256数据集。我们将彩色图像都转化为灰度图像,夹出来的中心广场区域,并采用双三次插值将图像调整到128×128的大小。训练数据集包括24485张图片,验证集和测试集都是由3061张图像组成。
6 总结
我们展示了SISTA是如何对应一个概率模型的推理,且能够视为深度循环网络SISTA-RNN。SISTA-RNN模型的训练权重能够在概率模型元素上是解释通的。而且,SISTA-RNN模型比之前的两个黑箱RNN模型在具体的图像压缩感知上表现效果更好。通过这个充满希望的初始结果,我们试图将SISTA-RNN模型应用到其他类型的数据集和将来对基于模型的深度网络帮助人类的理解上。