JAX简单来说就是支持GPU计算(比如CUDA)、自动求导(Autograd)的NumPy。在Python,进行数值计算离不开NumPy这个基础数值运算库。但NumPy不支持GPU或者其他硬件加速器,也不对Backwards Propagation的内置支持,再加上Python本身的速度限制,很少会在生产环境下直接使用NumPy训练或部署深度学习模型。这也是为什么会出现PyTorch、TensorFlow、Caffe等深度学习框架的原因。
JAX是机器学习框架领域的新生力量,其建立在加速线性代数XLA(Accelerated Linear Algebra)。具有更快的高阶求导,对GPU和TPU更好的支持。
通过XLA,JAX能对原生Python和NumPy进行自动微分。
XLA与JAX
JAX使用XLA在GPU和TPU上进行编译和运行NumPy程序,XLA是将JAX转化为加速器支持操作的中坚力量。
XLA作为深度学习编译器,其长期以来作为Google在深度学习领域的一个重要特性被开发,特别是hi作为TensorFlow 2.0背后支持之一,XLA也从实验特性变成了默认打开的特性。
XLA如何运行
XLA是一种针对特定领域的线性代数编译器,比如能够加快TensorFlow模型的运行速度,而且完全不需要更改源码。一般编译器都服务于如下目的,
- 提高代码执行速度。
- 优化存储使用。
XLA也不例外。XLA功能主要体现在以下几个方面,
- 融合可组合算子从而提高运算速度。
XLA通过对TensorFlow运行时计算图进行分析,将多个低级算子进行融合,从而生成高效的机器码。如下图,计算图中许多算子都是逐元素(element-wise)地计算,所以融合(Fuse)到一个element-wise的循环计算kernel中,从而减少了kernel加载,提高计算性能。
例如,上图中,bias向量里某单个元素与从matmul出来的结果某单个元素进行加法计算,对于其后的ReLU函数,这个加法计算结果是一个可与0进行比较的单元素。比较的结果使用输入值的幂进行指数运算和除法运算,从而产生softmax的结果。
将多个kernel融合为一个kernel,因此减少了加载kernel的时间消耗。
因此,XLA对模型的性能提升主要是将多个连续的element-wise算子融合为一个算子。
- 减少内存占用
在上面elemnt-wise计算整个过程,不需要为matmul、add、和ReLU等算子使用中间数组而开辟内存空间。通过将这些算子融合,可以减少申请这些酸子间的中间结果所占用的内存。
- 减少模型可执行文件大小
对移动设备场景,XLA可以减少模型的可执行文件的大小。通过AOT(Ahead of Time)编译将整个计算图生成轻量级的机器码,这些机器码实现了计算图中的各个操作。
在模型运行时,不需要在一个完整的运行环境,因为计算图实际执行时的操作被转换编译为设备代码。
- 方便支持不同硬件后端
当传统的深度学习应用支持一种新的设备时(比如不同的xPU指令集),需要将所有的kernels(OPs)再重新实现一遍,需要巨大的工作量。而通过XLA则需要很小的工作量,因为XLA的的算子都是操作原语(低级算子,primitive or atomic action,不可再分、不可中断的操作),数量少且容易实现,XLA会自动地将复杂的算子拆解为primitive算子。
可以说,XLA在TensorFlow 2.0上成功的实践,移植到到JAX,让JAX站在了成功的肩膀上。
XLA如何工作
XLA的输入语言是HLO IR(High Level Optimizer Intermediate Representation)。XLA将HLO描述的计算图(计算流程)编译成真对各种特定后端的机器指令。
整个流程如下图,
首先XLA对输入的HLO计算图进行目标设备无关的优化和分析,如CSE(Common Subexpression Elimination,公共子表达式删除,通过删除运算中等价子表达式的一种代码优化方法)、算子融合,运行时的内存分配分析。输出为优化后的HLO计算图。
然后将HLO计算图发送到后端(Backend),后端结合特定的硬件属性对HLO计算图进行进一步的HLO级优化,例如对某些操作或其组合进行模式匹配从而优化计算库调用。
最后,后端将HLO IR转化为LLVM IR,LLVM再进行低级优化并生成机器码。
JAX一般特性
JAX可以自动微分原生Python和NumPy函数。可以通过Python的大部分功能(包括循环、if、递归和闭包)进行微分。支持反向模式和正向模式微分,并且两者可以以任意顺序组成。
JAX的新功能是使用XLA再诸如GPU和TPU的加速器上编译和运行NumPy代码。默认情况下,编译器是后台运行的,而库调用将得到及时的编译和运行。但JAX允许使用单功能API将Python函数编译为XLA优化的内核。编译和自动微分可以任意组合,因此无需离开Python即可表达复杂的算法并获得最佳性能。
利用JIT加速程序运行
虽然精心编写的NumPy代码运行起来效率很高,但对于现代机器学习来说,还是希望这些代码运行的尽可能快。这一般通过在GPU或TPU(Tensor Processing Unit,张量处理单元)等不同的的“加速器”上运行代码来实现。
JAX提供了JIT(Just In Time即时)编译器,采用标准的Python、NumPy函数,经编译后可以在加速器上高效运行。编译函数还可以避免Python解释器的开销。总的来说,jax.jit可以显著加速代码运行,且基本上没有额外的编码开销,需要做的就是使用JAX编译函数。在使用jax.jit时,即使是微小的神经网络也可以实现相当惊人的加速效果。
下面演示使用jit加速的程序。
import time
import jax
def selu(inputs, alpha = 1.67, theta = 1.05):
return jax.numpy.where(inputs > 0, inputs, alpha * jax.numpy.exp(inputs) - alpha) * theta
@jax.jit
def selu1(inputs, alpha = 1.67, theta = 1.05):
return jax.numpy.where(inputs > 0, inputs, alpha * jax.numpy.exp(inputs) - alpha) * theta
def test():
key = jax.random.PRNGKey(15)
inputs = jax.random.normal(key = key, shape = (1000000,))
start = time.time()
selu(inputs = inputs)
end = time.time()
print(f"Time consumed by selu: %.2f seconds" % (end - start))
jit_selu = jax.jit(selu)
start = time.time()
jit_selu(inputs)
end = time.time()
print(f"Time consumed by jit_selu: %.2f seconds" % (end - start))
start = time.time()
selu1(inputs)
end = time.time()
print("Time consumed by selu1: %.2f seconds" % (end - start))
if __name__ == "__main__":
test()
第一个selu时Python的原生函数,第二个selu时通过jax.jit包装(wrap)后的函数,第三个是通过@jax.jit修饰的selu1函数。运行结果打印输出如下,
Time consumed by selu: 0.43 seconds
Time consumed by jit_selu: 0.10 seconds
Time consumed by selu1: 0.10 seconds
可见,相同的一段代码在不同的运行机制下速度有极大的不同。由于JAX充分利用了jit的特性,通过jit包装或修饰,使得函数在第一次调用后就被jit compile缓存。
自动求导器
除了评估数值函数,还可以使用自动求导。在JAX中,使用jax.grad函数来计算导数。jax.grad接受一个函数并返回一个新函数,该函数计算原始函数的梯度。要使用梯度下降,可以根据神经网络的参数计算损失函数的梯度,即使用jax.grad(loss_function)来计算损失函数的梯度。
下面演示对一个计算函数使用jax.grad进行求导。
import jax
def cumulate_exponents(inputs):
# Sigmoid Activation Funciton
# f(x) = 1.0/(1.0 + e⁻ˣ)
sigmoids = 1. / (1. + jax.numpy.exp(-inputs))
#
# Σsigmoids
return jax.numpy.sum(sigmoids)
def test():
# f'(x) = e⁻ˣ / (1 + e⁻ˣ)² = (1 - f(x)) x f(x)
# derivative_cumulate_exponents = f'(x)
# [0.25, ]
derivative_cumulate_exponents = jax.grad(cumulate_exponents)
# inputs = [0. 1. 2.]
# sigmoids = [0.5, 0.731058578630005, 0.880797077977882]
inputs = jax.numpy.arange(3.)
# Σsigmoids = Σ[0.5, 0.731058578630005, 0.880797077977882] = 2.1118555
outputs = cumulate_exponents(inputs)
derivative_outputs = derivative_cumulate_exponents(inputs)
print("inputs =", inputs, ", outputs =", outputs, ", derivative_outputs =", derivative_outputs)
if __name__ == "__main__":
test()
首先生成一个数值序列,对数值公式(其实是sigmoid激活函数)求和,事后使用自动求导器进行求导。注意,原函数求和后输出一个值,经过了求导,按照权重,又分配给3个数值。
运行结果打印输出如下,
inputs = [0. 1. 2.] , outputs = 2.1118555 , derivative_outputs = [0.25 0.19661197 0.10499357]
注意,按照jax的要求,jax.numpy.arange需要输入浮点型。
下面演示共同使用jax.jit函数和jax.grad函数的方法。
import jax
import time
def cumulate_exponents1(inputs):
return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))
def cumulate_exponents2(inputs):
return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))
def cumulate_exponents3(inputs):
return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))
@jax.jit
def cumulate_exponents_jit(inputs):
return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))
def test():
inputs = jax.numpy.arange(1024.)
start = time.time()
cumulate_exponents_jit_derivative = jax.grad(cumulate_exponents_jit)
print("cumulate_exponents_jit_derivative(inputs) = ", cumulate_exponents_jit_derivative(inputs))
end = time.time()
print("Time consumed: %.2f" % (end - start))
print("------------------------------")
start = time.time()
cumulate_exponents1_derivative = jax.grad(cumulate_exponents1)
print("cumulate_exponents1_derivative(inputs) = ", cumulate_exponents1_derivative(inputs))
end = time.time()
print("Time consumed: %.2f" % (end - start))
print("------------------------------")
start = time.time()
cumulate_exponents2_derivative = jax.grad(cumulate_exponents2)
cumulate_exponents2_derivative_jit = jax.jit(cumulate_exponents2_derivative)
print("cumulate_exponents2_derivative_jit(inputs) = ", cumulate_exponents2_derivative_jit(inputs))
end = time.time()
print("Time consumed: %.2f" % (end - start))
if __name__ == "__main__":
test()
运行结果打印输出如下,
cumulate_exponents_jit_derivative(inputs) = [0.25 0.19661197 0.10499357 ... 0. 0. 0. ]
Time consumed: 0.07
------------------------------
cumulate_exponents1_derivative(inputs) = [0.25 0.19661197 0.10499357 ... 0. 0. 0. ]
Time consumed: 0.19
------------------------------
cumulate_exponents2_derivative_jit(inputs) = [0.25 0.19661197 0.10499357 ... 0. 0. 0. ]
Time consumed: 0.03
可见,使用jit包装函数极大地提高了运行速度,而且当包装顺序调整后,花费时间只有原始函数的1/6,提升极大。
vmap函数实现自动向量化映射
JAX的API里还有另外一个重要转换函数vmap,即向量化映射。字面意思是创建函数,映射到以参数为轴函数。该函数具有沿着参数轴映射函数的熟悉语义(Familiar Semantics),但是不是将循环保留在外部,而是将循环推入函数的原始操作中以提高性能。当与jax.jit结合时,更能提高计算速度。
在实践中,当训练现代机器学习模型是,可以执行“小批量”的梯度下降,在梯度下降的每个步骤中,对一小批示例中的损失梯度求平均值;当示例的数据适中时,这样做完全没有问题,但当数据过多时,这样做会使得JAX在计算时消耗大量的时间。
解决的办法就是JAX额外提供了jax.vmap,该函数可以对函数进行“向量化”处理,也就是说该函数允许在输入的某个轴上并行计算函数的输出。简单来说,就是可应用jax.vmap函数向量化并立即获得损失函数梯度的版本,该版本适用于小批量示例。
下面用代码说明jax.vmap方法。
import jax
import time
def cumulate_exponents(inputs):
return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))
def test():
start = time.time()
inputs = jax.numpy.arange(1024000.)
cumulate_exponents1_derivative = jax.grad(cumulate_exponents)
cumulate_exponents1_derivative(inputs)
end = time.time()
print("Time consumed: %.2f" % (end - start), "when function executing cumulate_exponents1_derivative")
start = time.time()
inputs = jax.numpy.arange(1024000.)
cumulate_exponents2_derivative = jax.vmap(jax.grad(cumulate_exponents))
cumulate_exponents2_derivative(inputs)
end = time.time()
print("Time consumed: %.2f" % (end - start), "when function executing cumulate_exponents2_derivative")
start = time.time()
inputs = jax.numpy.arange(1024000.)
cumulate_exponents3_derivative = jax.jit(jax.vmap(jax.grad(cumulate_exponents)))
cumulate_exponents3_derivative(inputs)
end = time.time()
print("Time consumed: %.2f" % (end - start), "when function executing cumulate_exponents3_derivative")
if __name__ == "__main__":
test()
运行结果打印输出如下,
Time consumed: 0.29 when function executing cumulate_exponents1_derivative
Time consumed: 0.22 when function executing cumulate_exponents2_derivative
Time consumed: 0.04 when function executing cumulate_exponents3_derivative
可以看到,随着加载更多JAX的特性函数,计算时间依次递减。
除了jax.grad、jax.jit以及jax.vmap,在介绍另外两个函数,
- in_axes,是一个元组或整数,该函数高速JAX函数参数应该对哪些轴并行化。元组应该与vmap函数的参数数量相同,或者只有一个参数是为整数。比如,使用 (None, 0, 0)是指不在第一个参数(argument)上并行化,并在第二个参数和第三个参数的第一个(索引0)维度上并行化。
- out_axes,与in_axes类似。指定了函数输出的哪些轴并行化。
结论
本章探讨了JAX和XLA的特性,以及XLA运行机制。另外,通过代码说明jax.grad、jax.jit、jax.vmap的功能和用法。