第8章 【初级劝退】自动微分器

在第4章和第7章,分别使用线性回归和多层感知机来对鸢尾花进行分类。在训练过程中,都使用了

jax.grad()

对损失函数进行梯度下降算法计算,从而更新参数。jax.grad()是自动微分器,如下图,

图1 自动微分器

一直以来,自动微分在神经网络框架背后默默地运行着,为神经网络的提供了底层支撑。

图2 自动微分在神经网络框架中

导数、微分、梯度

简单了解一下三者的概念。

导数

函数在某一点出变化的快慢,是变化率。从几何角度来说,表示曲线上某点(自变量变化量趋近于0)处切线的斜率。可以用物体受力运动的距离和时间的关系来理解,导数就是速度变化率——瞬时速度。速度是关于位移的一阶导数,加速度是关于速度的一介导数,加速度是关于位移的二阶导数。

微分

指函数在某一点处(趋近无穷小)的变化量,是一个数值。对于多元函数的全微分,就是各个自变量微分的和,即各个分变化量的和。微分本质是一个微小的线性变化量,是用一个线性函数作为原函数变化的逼近(或者叫近似),微分就是为了方便用线性函数分析原函数,毕竟线性函数容易理解和分析问题。

导数和微分就相当于速度和路程的关系。

梯度

梯度是一个向量,既有方向又有大小。方向是函数在某点处变化率最大的方向,此时,梯度的大小就是变化率,根据上面导数的定义可知,梯度的大小就是该点处的导数。一个典型例子是道士下山。假如道士在山顶,要赶在大雨前下到山脚。不考虑山路行走难度,怎么找到最快的路径?找最陡峭的下山山路。这就要每走几步就找一下最陡峭路径。

什么是微分器

在数学和计算机代数学中,自动微分也被称为微分算法或者数值微分。他是一种数值计算的方式,用来计算因变量对自变量的导数,所以又叫自动求导。自动微分器是一种计算机程序,与手动计算微分的“分析法”不一样。

自动微分器的基于一个事实,即一个计算机程序,无论多么复杂,都在执行加、减、乘、除这些基本算术运算,以及指数、对数、三角函数等这类初等函数运算。通过将链式求导法则应用到这些运算上,则能以任意精度自动地计算导数,而且最多只比原始程序多一个常数级的运算。

注意,由于微分器是一种数值计算方式,所以仅仅能处理数值相关的计算。

例如,对下面公式求解微分,

将上述公式转换成计算图形式,如下所示,

图3 公式转换成计算图

图中每个圆圈表示操作产生的中间结果,下标顺序并表示计算顺序。根据计算图一步步来计算函数值,如下图所示。其中,左侧表示数值计算过程,右侧表示梯度计算过程后。

图4 数值及梯度计算过程

JAX中的自动微分

举一个例子讲解JAX的自动微分,

设x = 2,

根据高等数学里求导方法可知,

指数函数求导公式如下,

则,

代入x = 2,

如果使用JAX完成上述计算,代码如下所示,


import jax

def function(x):

    """
    f(x) = x³ + 2x² + 3x + 1
    """
    
    return x ** 3 + 2 * x ** 2 + 3 * x + 1
    
def test():

    x = 2.0
    y = function(x)
    
    print(f"function({x}) = {y}")
    
    derivative_function = jax.grad(function)
    y = derivative_function(x)
    
    print(f"derivative_function({x}) = {y}")
    
    second_derivative_function = jax.grad(derivative_function)
    y = second_derivative_function(x)

    print(f"second_derivative_function({x}) = {y}")
    
    third_derivative_function = jax.grad(second_derivative_function)
    y = third_derivative_function(x)
    
    print(f"third_derivative_function({x}) = {y}")
    
if __name__ == "__main__":

    test()

打印输出如下所示,


function(2.0) = 23.0
derivative_function(2.0) = 23.0
second_derivative_function(2.0) = 16.0
third_derivative_function(2.0) = 6.0

打印输出和上述手动计算结果一致。

注意,jax.grad要求必须使用float类型,而不能是int类型。jax.grad是jax的微分程序,对结果自动求导。可以看到,jax.grad是一个求导借口,其输入/输出都是函数,因此借助于jax.grad方便地去做高阶求导。

一般在神经网络中求导并不是单个数,而是一个序列(矩阵)多个数字共同求导。对于这个问题,解决方法如下,


import jax

def function(x):

    """
    f(x) = x³ + 2x² + 3x + 1
    """

    return x ** 3 + 2 * x ** 2 + 3 * x + 1
    
def test():

    # Sum the function(x), then grad
    derivative_function = jax.grad(lambda x: jax.numpy.sum(function(x)))
    
    # [1., 2., 3., 4., 5.]
    x = jax.numpy.linspace(1, 5, 5)
    print(f"x = {x}")
    
    y = derivative_function(x)
    print(f"derivative_function({x}) = {y}")
    
if __name__ == "__main__":

    test()


打印输出如下所示,

x = [1. 2. 3. 4. 5.]
derivative_function([1. 2. 3. 4. 5.]) = [10. 23. 42. 67. 98.]

JAX的求导函数先对函数值进行求和计算,之后根据求和结果对求和值进行求导,然后根据所占的权重和位置分解到每一个数值。

从更底层来说,jax.grad使用的是反向自动微分模式,

jax.vjp()

Jax.vjp()根据原始函数function,输入x计算得出函数结果y并生成微分用的线性方程(参考上面微分说明部分)。jax.grad默认采用反向自动微分,从底层调用vjp()。

结论

本章介绍了微分器的概念,粗略介绍了导数、微分以及梯度的概念,使用jax.grad进行自动微分的计算。希望不要劝退。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容