Neural Ordinary Differential Equations 神经常微分方程

0 摘要

我们引入了一个新的深度神经网络模型家族. 我们没有用非连续的隐藏层, 而是用神经网络把隐状态的导数参数化. 网络的输出是通过黑盒微分方程求解器来计算的. 这些连续层的网络的内存消耗是稳定不变的, 针对每个输入来设计估计方法的话, 就能做计算精度和计算速度的权衡. 我们通过连续层的ResNet和连续时间的隐变量模型展示了这些特性. 我们还构造了连续正则化流, 该生成模型可以直接用极大似然来训练, 不需要对数据维度进行分区或者排序. 我们展示了训练过程其实不需要了解ODE求解器内部的实现, 也能对ODE求解器的反向进行计算. 这就允许我们构造大规模模型, 并进行端到端的训练.

1 引言

像ResNet, RNN解码器, 正则化流, 它们都组合了隐状态的一系列变换, 构建出一个复杂的变换, 如下:

image.png

其中t属于[0…T], ht属于Rd, 这些迭代更新可以看作是连续变换的欧拉离散化. 当t趋于0, step趋于无穷时, 可以得到如下的常微分方程(ODE, ordinary differential equation):

image.png

给定h(0), 我们可以把h(T) 作为该方程在T时刻的解. 该解可以用黑盒ODE求解器计算得到, 求解器还能根据需要的精度自行决定在何处对f进行拟合. 图1对比这一过程:

image.png

图1

左: ResNet定义了一系列非连续的有限转换.

右: ODE网络定义了一个向量场, 可以对隐状态进行连续的转换.

黑点表示估计点.

定义 使用ODE的模型有如下好处:

内存优化:

在第2节, 我们展示了如何在不涉及ODE求解器黑盒内部操作的情况下, 对任意ODE求解过程求反向, 得到标量损失的梯度. 不储存任何前向计算结果, 就可以让我们在内存占用不变的情况下训练任意深度的模型. 这就解决了深度神经网络模型训练的主要瓶颈---模型深度.

自适应计算法

欧拉法求ODE是比较古老的方法了, 现代ODE求解器可以做到根据误差精度要求来调整求解过程, 监控误差来获得需要的精度. 这就可以根据问题复杂度来调整模型估值的消耗. 在模型训练结束后, 还能降低计算精度来满足程序实时性的要求.

可拓展和可逆的标准化流

连续变换带了一个意想不到的好处, 变量方程式的变化更加容易计算了. 在第4节, 我们提出这个结论并组建了一个可逆的密度模型, 该模型可以避免正则化流中单单元的瓶颈, 可以直接用极大似然来进行训练.

连续时间序列模型

RNN需要离散的观测和发射间隔, 而定义连续的模型可以接收任意时间得到的数据. 此种模型的构建和展示详见第5节.

2 ODE求解器的反向自动微分

训练连续层网络的主要问题就是对ODE求解器的反向微分(也叫反向传播). 直接根据求解器内部操作来求微分的内存占用过大, 并且会引入额外的误差.

我们把ODE求解器当做黑盒, 用”伴随灵敏度法”(adjoint sensitivity method)来求梯度. 这种计算法是通过计算另一个参数化的ODE来实现的. 这种方法的复杂度会根据问题的规模线性变化, 内存占用也很低, 并且可以显式的控制计算精度.

假设标量的损失函数为L, 输入是ODE求解器的结果:

image.png

为最小化L, 就需要求L对θ的梯度, 第一步就是要求L在每一个时刻对隐状态z(t)的梯度. 这部分被称为”伴随”:

image.png

它也是一个ODE, 可以视作瞬时的链式法则:

image.png

这样, 再调一次求解器就可以解出
image.png

. 这个求解是反向进行的, 初始状态是
image.png

解这个ODE就需要知道从t0到t1轨迹上的所有z(t). 所以在求伴随的过程中需要把z(t)也一并解出, 就可以在中间的轨迹上使用z(t)的值来求a(t)了.

计算L对θ的偏导则需要求第三个积分式:

image.png

这个式子需要知道z(t)和a(t)的值.

image.png

image.png

这两个向量-jacobian 乘积可以通过一次自动微分直接得到, 时间消耗跟对f的估值差不多. 只要把初始状态, 伴随和另一个偏导 concat 到一个向量中, 所有求解z,a和
image.png

的积分, 都可以通过调用一次ODE求解器计算得出. 如下算法1的伪代码:
image.png

大多数的ODE求解器都可以输出中间计算结果z(t), 当loss取决于这些中间状态时, 反向偏导的计算也必须拆成一系列的求解. 如图2所示:


image.png

图2: ODE求解器的反向过程.

伴随敏感度法求反向是分时刻实时求解的. 参数化的系统包括了初始状态以及loss对状态的灵敏度. 如果损失直接依赖于多个时刻的隐状态的观测, 伴随状态也必须在loss对观测的偏导方向上更新.

在每个观测处, 伴随都必须跟着偏导
image.png

的方向调整.

在附录C中给出了L关于t0, t1偏导的解法. 附录B中给出上面公式的详细推导过程. 附录D给出了上述算法scipy实现, 这部分代码也支持更高阶的微分.

https://github.com/rtqichen/torchdiffeq中还给出了pytorch版本的实现.

3 用ODE来取代ResNet进行有监督的训练

本节尝试用神经ODE进行有监督训练.

软件: (作者说自己选取了某某ODE求解器, 还用一个第三方框架实现了求反向, 但是在pytorch版代码中这些都对不上)

模型结构: 使用了一个小的残差网络, 对输入进行了2次下采样, 然后叠了6个标准残差链接层, 这6个残差连接层替换成ODE求解器模块. 还测试了一下同样结构, 但是反向直接用链式法则求解的网络, 记为RK-Net. 各网络的表现如下:

image.png

可以看到, ODE网络和RK网络可以达到和ResNet相同的性能.

ODE****网络的误差控制: ODE求解器可以保证计算误差在真实解的某个误差限内. 更改这个误差限会改变网络的性能表现. 图3a展示了误差是可控的. 图3b展示了前向计算时间是跟着函数估值次数成比例增加的. 所以降低误差限可以在计算速度和精度之间做取舍. 你可以在训练时用高精度, 但是在推理时用低精度来加快速度..

image.png

图3c表明: 反向计算的消耗只有前向计算的一半左右. 这就表明, 伴随法不但节省内存, 还比直接求反向更加高效.

网络深度: 在ODE中不太好直接定义网络层数这个概念. 有点类似的是隐状态方程估值所需的次数, 这依赖于ODE求解器的输入和初始状态. 图3d展示了训练过程中估值次数的增加, 这对应了模型复杂度的增长.

4 连续正则化流

还有一个模型也出现了类似式1的非连续型方程, 那就是正则化流(NF, normalization flows)和NICE framework. 这些模式使用变量代换定理来计算可逆变换之后的概率密度.


image.png

经典的正则化流模型: planar normalization flows的公式如下:

image.png

一般来说, 使用变量代换公式的瓶颈是计算雅克比矩阵
image.png

, 它的计算复杂度要么是z维度的立方, 要么是隐藏单元数量的立方. 最近的研究都是在NF模型的表达能力和计算复杂度做取舍.

令人惊讶的是, 我们把非连续的模型公式, 用第3节同样的思路来转换成连续模型可以减少计算量.

定理1: 变量瞬时变化

设z(t)是一个有限连续随机变量,概率p(z(t))依赖于时间. 则下式是z(t)随时间连续变化的微分方程:

image.png

假设f在z上均匀Lipschitz连续,在t上连续,那么对数概率密度的变化也遵循微分方程:


image.png

证明见附录A. 与式6的log计算不同, 本式只需要计算迹(trace)的操作. 另外, 不像标准的NF模型, 本式不要求f是可逆的, 因为如果满足唯一性,那么整个转换自然就是可逆的.

应用变量瞬时变化定理,我们可以看一下planar normalization flows的连续模拟版本:

image.png

给定一个初始分布p(z(0),我们可以从p(z(T))中采样,并通过求解这组ODE来评估其概率密度。

使用多个线性成本的隐藏单元

当det(行列式)不是线性方程时, 迹的方程还是线性的, 并且满足:

image.png

这样我们的方程就可以由一系列的求和得到, 概率密度的微分方程也是一个求和:

image.png

这意味着我们可以很简便的评估多隐藏单元的流模型,其成本仅与隐藏单元M的数量呈线性关系。使用标准的NF模型评估这种“宽”层的成本是O(M3),这意味着标准NF体系结构的多个层只使用单个隐藏单元.

依赖于时间的动态方程

我们可以将流的参数指定为t的函数,使微分方程f(z(t)、t)随t而变化。这种参数化的方法是一种超网络. 我们还为每个隐藏层引入了门机制:

image.png

其中:
image.png

, 是一个神经网络, 可以学习到何时使用fn. 我们把该模型称之为连续正则化流(CNF, continuous normalizing flows)

4.1 CNF试验

我们首先比较连续的和离散的planar正则化流在学习样本从一个已知的分布。我们证明了一个具有M个隐藏单元的连续 planar CNF至少可以与一个具有K层(M = K)的离散 planar NF具有同样的拟合能力,某些情况下CNF的拟合能力甚至更强.

拟合概率密度

设置一个前述的CNF, 用adam优化器训练10000个step. 对应的NF使用RMSprop训练500000个step. 此任务中损失函数为KL (q(x)||p(x)), 最小化这个损失函数, 来用q(x)拟合目标概率分布p(x). 图4表明, CNF可以得到更低的损失.

[图片上传失败...(image-7d47a5-1616472352555)]

极大似然训练

CNF一个有用的特性是: 计算反向转换和正向的成本差不多, 这一点是NF模型做不到的. 这样在用CNF模型做概率密度估计任务时, 我们可以通过极大似然估计来进行训练 也就是最大化log(q(x))的期望值. 其中q是变量代换之后的函数. 然后反向转换CNF来从q(x)中进行采样.

该任务中, 我们使用64个隐藏单元的CNF和64层的NF来进行对比. 图5展示了最终的训练结果. 从最初的高斯分布, 到最终学到的分布, 每一个图代表时间t的某一步. 有趣的是: 为了拟合两个圆圈, CNF把planar 流 进行了旋转, 这样粒子会均分到两个圆中. 跟 CNF的平滑可解释相对的是, NF模型比较反直觉, 并且很难拟合双月牙的概率分布(见图5.b)

[图片上传失败...(image-687aaa-1616472365302)]

5 生成式隐方程时间序列模型

将神经网络应用于不规则采样的数据,如医疗记录、网络流量或神经尖峰数据是困难的。 通常,观测被放入固定持续时间的桶中,隐方程(变量?原文是dynamic)以同样的方式进行离散。如果存在数据缺失或隐变量定义不当的情况, 问题就比较困难. 数据缺失可以用数据填充和生成时间序列模型来进行标记. 还有一种方式是给RNN的输入加时间戳信息.

我们提出了一种连续时间,生成的方法来建模时间序列。我们的模型用一个隐轨迹来表示每个时间序列。每个轨迹都是由一个局部初始状态zt0和跨所有时间序列共享的全局隐方程组来确定。给定观测时间t0、t1、……tN和初始状态zt0,ODE求解算器产生zt1,…ztN,描述每个观测的潜在状态。我们通过一个采样程序正式地定义了这个生成模型:

image.png

函数f是一个时间无关的函数,在当前时间步长取z并输出梯度:

image.png

我们用神经网络来参数化这个方程. 因为f是时间无关的, 给定隐状态z(t), 整个隐轨迹就是唯一确定的. 推断隐轨迹可以让我们在时间上任意向前或后退做出预测

image.png

训练与预测

我们可以用观测的序列将这个潜变量模型训练为变分自动编码器. 我们的判别模型RNN倒序的接收时间序列数据, 输出q φ (z 0 |x 1 ,x 2 ,...,x N ). 详见附录E. 使用ODE来做生成模型, 我们就能在已知时间序列的情况下, 在任意时间点做出预测.

泊松过程似然

观测本身就给出了一些隐状态的信息, 比如说: 得病的人更倾向于做药物测试. 事件发生率可以用隐方程来进行参数化:

image.png

给定这个概率函数,非均匀泊松过程给出了区间[tstart,tend]中独立观测的可能性:

image.png

我们可以使用另一个神经网络来参数化λ(·)。因此,我们可以调用一次ODE求解器就评估出隐轨迹和泊松过程概率值。图7为该模型在数据集上学习到的事件发生率。


image.png

观测时间上的泊松过程似然可以与数据似然相结合,共同模拟所有观测和时间。

5.1 事件序列隐ODE试验

我们研究了隐ODE模型的拟合和推断时间序列的能力。该判别网络是一个有25个隐藏单元的RNN。我们使用一个四维的隐空间。我们用一个具有20个隐藏单元的单隐藏层网络来参数化函数f。解码器是一个神经网络, 只有一个隐藏层, 20个隐藏单元, 用于计算p(x t i |z t i )。我们的基线是一个有25个隐藏单元的RNN,用最小化负高斯对数似然为目标函数训练。我们训练了这个RNN的第二个版本,其输入与下一个观测的时间差连接,以帮助RNN进行不规则的观测。

双向螺旋数据集

我们生成了一个1000个二维螺旋的数据集,每个螺旋从一个不同的点开始,在100个相同间隔的时间步长采样。 数据集包含两种类型的螺旋:一半是顺时针方向,另一半是逆时针方向。 为了模拟真实情况,我们在观测中加入高斯噪声。

具有不规则时间点的时间序列

为了生成不规则的时间戳,我们不替换的从每个轨迹随机采样 (n={30,50,100}). 训练数据之外, 我们展示了100个时间点的预测均方根误差(RMSE)。 表2显示,隐ODE预测时的RMSE明显较低.

image.png

图8展示了用下采样的30个点来拟合螺旋的结果.

[图片上传失败...(image-53050-1616472422476)]

隐ODE的重构是通过对潜在轨迹的后验采样并将其解码为数据空间得到的. 附录F展示了更多不同数据点的情况. 我们发现, 不管多少个点的下采样, 不管有没有高斯噪声, 重建和推断都和真实情况一致.

隐空间推断

图8c展示了隐轨迹投影到隐空间前2个维度的结果. 这是两个轨迹群, 一个顺时针一个逆时针. 图9展示了: 初始状态隐轨迹方程为顺时针, 而后转变为逆时针, 这一转变过程是非常连续的.


image.png

6 应用范围与限制.

Mini-Batch

Mini-Batch的使用不如标准神经网络那么直观。我们仍然可以通过将每个batch的状态连接在一起,创建维度D×K的ODE方程组,通过ODE求解器来计算。In some cases, controlling

error on all batch elements together might require evaluating the combined system K times more

often than if each system was solved individually(不太懂什么意思)。不过,在实践中使用Mini-Batch时,计算量并没有大幅增加.

唯一性

什么情况下连续方程有唯一解? 皮卡存在定理限定了, 当微分方程Lipschitz连续并且z在t上连续时, 初值问题的解存在且唯一. 这就对我们使用的神经网络有所限制, 模型的权重有限, 且不能使用非Lipschitz连续的激活函数, 比如tanh或者relu.

设置计算精度

模型允许用户在计算精度和速度之间做trade-off, 需要用户在训练的前向和反向中设置误差限. 对于序列模型, 默认值为1.5e-8. 在分类和概率密度拟合问题中, 不降低模型性能的情况下, 默认值可设置为1e-3和1e-5.

重建前向轨迹

如果重建的轨迹偏离了原轨迹,则通过向后运行的方程来重建状态轨迹会带来额外的数值误差。这个问题可以通过checkpoint来解决:将z的中间值存储在前向过程中,并通过从这些点重新积分来重建精确的前向轨迹。不过在实际计算中这不是一个问题,多层CNF的反向可以恢复到初始状态.

7 相关工作

8 结语

我们探索了黑盒ODE求解器作为模型的一部分, 并用它开发了新模型可以用于时间序列问题, 监督学习问题, 概率密度估计问题. 这些模型可以自适应的进行估值计算, 并且允许用户显式的在计算速度和精度之间做取舍. 最终, 我们提出了连续版本的变量代换模型, 命名为CNF, 该模型的层可以扩展到比较大的尺度.

9 注:

我没有对附录和参考文献做翻译, 这部分大家请下载论文原文查看: https://arxiv.org/pdf/1806.07366

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,332评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,508评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,812评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,607评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,728评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,919评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,071评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,802评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,256评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,576评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,712评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,389评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,032评论 3 316
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,798评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,026评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,473评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,606评论 2 350

推荐阅读更多精彩内容