在第4章和第7章,分别使用线性回归和多层感知机来对鸢尾花进行分类。在训练过程中,都使用了
jax.grad()
对损失函数进行梯度下降算法计算,从而更新参数。jax.grad()是自动微分器,如下图,

一直以来,自动微分在神经网络框架背后默默地运行着,为神经网络的提供了底层支撑。

导数、微分、梯度
简单了解一下三者的概念。
导数
函数在某一点出变化的快慢,是变化率。从几何角度来说,表示曲线上某点(自变量变化量趋近于0)处切线的斜率。可以用物体受力运动的距离和时间的关系来理解,导数就是速度变化率——瞬时速度。速度是关于位移的一阶导数,加速度是关于速度的一介导数,加速度是关于位移的二阶导数。
微分
指函数在某一点处(趋近无穷小)的变化量,是一个数值。对于多元函数的全微分,就是各个自变量微分的和,即各个分变化量的和。微分本质是一个微小的线性变化量,是用一个线性函数作为原函数变化的逼近(或者叫近似),微分就是为了方便用线性函数分析原函数,毕竟线性函数容易理解和分析问题。
导数和微分就相当于速度和路程的关系。
梯度
梯度是一个向量,既有方向又有大小。方向是函数在某点处变化率最大的方向,此时,梯度的大小就是变化率,根据上面导数的定义可知,梯度的大小就是该点处的导数。一个典型例子是道士下山。假如道士在山顶,要赶在大雨前下到山脚。不考虑山路行走难度,怎么找到最快的路径?找最陡峭的下山山路。这就要每走几步就找一下最陡峭路径。
什么是微分器
在数学和计算机代数学中,自动微分也被称为微分算法或者数值微分。他是一种数值计算的方式,用来计算因变量对自变量的导数,所以又叫自动求导。自动微分器是一种计算机程序,与手动计算微分的“分析法”不一样。
自动微分器的基于一个事实,即一个计算机程序,无论多么复杂,都在执行加、减、乘、除这些基本算术运算,以及指数、对数、三角函数等这类初等函数运算。通过将链式求导法则应用到这些运算上,则能以任意精度自动地计算导数,而且最多只比原始程序多一个常数级的运算。
注意,由于微分器是一种数值计算方式,所以仅仅能处理数值相关的计算。
例如,对下面公式求解微分,

将上述公式转换成计算图形式,如下所示,

图中每个圆圈表示操作产生的中间结果,下标顺序并表示计算顺序。根据计算图一步步来计算函数值,如下图所示。其中,左侧表示数值计算过程,右侧表示梯度计算过程后。

JAX中的自动微分
举一个例子讲解JAX的自动微分,

设x = 2,
根据高等数学里求导方法可知,

指数函数求导公式如下,

则,



代入x = 2,




如果使用JAX完成上述计算,代码如下所示,
import jax
def function(x):
"""
f(x) = x³ + 2x² + 3x + 1
"""
return x ** 3 + 2 * x ** 2 + 3 * x + 1
def test():
x = 2.0
y = function(x)
print(f"function({x}) = {y}")
derivative_function = jax.grad(function)
y = derivative_function(x)
print(f"derivative_function({x}) = {y}")
second_derivative_function = jax.grad(derivative_function)
y = second_derivative_function(x)
print(f"second_derivative_function({x}) = {y}")
third_derivative_function = jax.grad(second_derivative_function)
y = third_derivative_function(x)
print(f"third_derivative_function({x}) = {y}")
if __name__ == "__main__":
test()
打印输出如下所示,
function(2.0) = 23.0
derivative_function(2.0) = 23.0
second_derivative_function(2.0) = 16.0
third_derivative_function(2.0) = 6.0
打印输出和上述手动计算结果一致。
注意,jax.grad要求必须使用float类型,而不能是int类型。jax.grad是jax的微分程序,对结果自动求导。可以看到,jax.grad是一个求导借口,其输入/输出都是函数,因此借助于jax.grad方便地去做高阶求导。
一般在神经网络中求导并不是单个数,而是一个序列(矩阵)多个数字共同求导。对于这个问题,解决方法如下,
import jax
def function(x):
"""
f(x) = x³ + 2x² + 3x + 1
"""
return x ** 3 + 2 * x ** 2 + 3 * x + 1
def test():
# Sum the function(x), then grad
derivative_function = jax.grad(lambda x: jax.numpy.sum(function(x)))
# [1., 2., 3., 4., 5.]
x = jax.numpy.linspace(1, 5, 5)
print(f"x = {x}")
y = derivative_function(x)
print(f"derivative_function({x}) = {y}")
if __name__ == "__main__":
test()
打印输出如下所示,
x = [1. 2. 3. 4. 5.]
derivative_function([1. 2. 3. 4. 5.]) = [10. 23. 42. 67. 98.]
JAX的求导函数先对函数值进行求和计算,之后根据求和结果对求和值进行求导,然后根据所占的权重和位置分解到每一个数值。
从更底层来说,jax.grad使用的是反向自动微分模式,
jax.vjp()
Jax.vjp()根据原始函数function,输入x计算得出函数结果y并生成微分用的线性方程(参考上面微分说明部分)。jax.grad默认采用反向自动微分,从底层调用vjp()。
结论
本章介绍了微分器的概念,粗略介绍了导数、微分以及梯度的概念,使用jax.grad进行自动微分的计算。希望不要劝退。