前第22章我们提到过vmap函数实现自动向量化映射,对数据进行打包和计算,本章将深入介绍vmap的一些细节。
用剥洋葱的方式处理数组
传统的批处理方式很一般都是对数据的批处理,之后再使用函数从外到内一层层地重新计算。首先我们实现如下代码。
import jax
def convolve(inputs, weights):
outputs = []
stride_size = 2
# i in [0, 1, 2]
# len(inputs) = 5
for i in range(len(inputs) - stride_size):
# inputs[0, 2], inputs[1, 3], inputs[2, 4]
slices = inputs[i: i + stride_size + 1]
dotted = jax.numpy.dot(slices, weights)
outputs.append(dotted)
return jax.numpy.array(outputs)
def test():
# [0, 1, 2, 3, 4]
inputs = jax.numpy.arange(5)
weights = jax.numpy.array([2., 3., 4.])
outputs = convolve(inputs, weights)
print("outputs = ", outputs)
if __name__ == "__main__":
test()
运行结果打印输出如下,
outputs = [11. 20. 29.]
这是一个非常简单的矩阵乘法运算。假设此时我们需要对其进行修改,将原来的inputs和weights分别进行二次堆叠,即发生如下变化,
inputs = jax.numpy.stack([inputs, inputs])
weights = jax.numpy.stack([weights, weights])
再添加一个新的计算函数def batched_convolve(inputs, weights),代码如下所示,
import jax
def convolve(inputs, weights):
outputs = []
stride_size = 2
# i in [0, 1, 2]
# len(inputs) = 5
for i in range(len(inputs) - stride_size):
# inputs[0: 3], inputs[1: 4], inputs[2, 5]
slices = inputs[i: i + stride_size + 1]
dotted = jax.numpy.dot(slices, weights)
outputs.append(dotted)
return jax.numpy.array(outputs)
def batched_convolve(inputs, weights):
outputs = []
for i in range(inputs.shape[0]):
outputs.append(convolve(inputs[i], weights[i]))
outputs = jax.numpy.stack(outputs)
return outputs
def test():
# [0, 1, 2, 3, 4]
inputs = jax.numpy.arange(5)
weights = jax.numpy.array([2., 3., 4.])
# outputs = convolve(inputs, weights)
# print(f"inputs = {inputs}, weights = {weights}, outputs = ", outputs)
inputs = jax.numpy.stack([inputs, inputs])
weights = jax.numpy.stack([weights, weights])
outputs = batched_convolve(inputs, weights)
print(f"inputs = {inputs}, weights = {weights}, outputs = ", outputs)
if __name__ == "__main__":
test()
新的函数调用了原来的函数,并将其作为自身的一部分进行计算和输出。运行结果打印输出如下,
inputs = [[0 1 2 3 4]
[0 1 2 3 4]], weights = [[2. 3. 4.]
[2. 3. 4.]], outputs = [[11. 20. 29.]
[11. 20. 29.]]
可以看到,这是完成输出的运算结果。结果是正确的,但是从效率方面来说,这个函数并不高。为了有效地对计算进行批次处理,通常需要手动重写函数,以确保函数是以矢量化形式完成的。这并不难实现,但设计更该函数如何处理索引(index)、轴(axes)和输入等内容。手动实现批处理的程序如下所示,
def vectorized_convolve_v1(inputs, weights, stride_size = 2):
outputs = []
for i in range(0, inputs.shape[-1] - stride_size):
slices = inputs[:, i: i + stride_size + 1]
dotted = slices @ weights.T
outputs.append(dotted)
jax.numpy.stack(outputs, axis = 1)
return outputs
def test():
# [0, 1, 2, 3, 4]
inputs = jax.numpy.arange(5)
weights = jax.numpy.array([2., 3., 4.])
# outputs = convolve(inputs, weights)
# print(f"inputs = {inputs}, weights = {weights}, outputs = ", outputs)
inputs = jax.numpy.stack([inputs, inputs])
weights = jax.numpy.stack([weights, weights])
'''
outputs = batched_convolve(inputs, weights)
print(f"inputs = {inputs}\n, weights = {weights},\n outputs = ", outputs)
print("------------------------------------------")
outputs = vectorized_convolve(inputs, weights)
print(f"inputs = {inputs}, \nweights = {weights},\n outputs = ", outputs)
print("------------------------------------------")
'''
outputs = vectorized_convolve_v1(inputs, weights)
print(f"inputs = {inputs}, \nweights = {weights},\n outputs = ", outputs)
if __name__ == "__main__":
test()
运行结果打印输出如下,
inputs = [[0 1 2 3 4]
[0 1 2 3 4]],
weights = [[2. 3. 4.]
[2. 3. 4.]],
outputs = [Array([[11., 11.],
[11., 11.]], dtype=float32), Array([[20., 20.],
[20., 20.]], dtype=float32), Array([[29., 29.],
[29., 29.]], dtype=float32)]
可以看到,无论采用何种方法,都可以完成数据的计算,然而这所有的方法和算法都是基于对数据的打包,仍然是一次次地将数据输入到函数中进行计算。能否改变一下思路,在数据不动的情况下,。打包函数进行计算呢?
自动向量化函数vmap
在JAX中,jax.vmap转换被设计成为在参数轴上自动生成一个函数的自动化打包器。使用上一节中提到的convolve函数并计算其处理多维数据的结果,代码如下所示,
import jax
def convolve(inputs, weights, stride_size = 2):
outputs = []
for i in range(inputs.shape[-1] - stride_size):
slices = inputs[i: i + stride_size + 1]
dotted = jax.numpy.dot(slices, weights)
outputs.append(dotted)
outputs = jax.numpy.array(outputs)
return outputs
def test():
inputs = jax.numpy.arange(5)
# inputs = jax.numpy.stack([inputs, inputs])
weights = jax.numpy.array([2., 3., 4.])
# weights = jax.numpy.stack([weights, weights])
outputs = convolve(inputs, weights)
print("outputs = ", outputs)
auto_batch_convolve = jax.vmap(convolve)
inputs = jax.numpy.stack([inputs, inputs])
weights = jax.numpy.stack([weights, weights])
outputs = auto_batch_convolve(inputs, weights)
print("outputs = ", outputs)
if __name__ == "__main__":
test()
运行结果打印输出如下,
outputs = [11. 20. 29.]
outputs = [[11. 20. 29.]
[11. 20. 29.]]
注意,使用jax.vmap进行包装的代码,jax.vmap通过类似一jax.jit的跟踪函数来实现对函数的自动化打包,并在每个输入的开头自动添加批处理轴,不用再使用jax.numpy.stack(outputs, axis = 1)进行堆叠。
如果批处理维度不是第一个,则可以使用in_axes和out_axes参数来指定批处理维度在输入和输出的位置。如果所有输入和输出的批处理相同,或者列表相同,则为正数。代码如下所示,
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes = 1, out_axes = 1)
inputs = jax.numpy.transpose(inputs)
weights = jax.numpy.transpose(weights)
outputs = auto_batch_convolve_v2(inputs, weights)
print("outputs = ", outputs)
运行结果打印输出如下,
outputs = [11. 20. 29.]
outputs = [[11. 20. 29.]
[11. 20. 29.]]
outputs = [[11. 11.]
[20. 20.]
[29. 29.]]
此外,还有一种情况,提提供了两个经过批处理后的数据,但是在某些情况下可能只有一个数据被批处理进修正,此时vmpa同样可以对其操作,代码如下所示,
auto_batch_convolve_v3 = jax.vmap(convolve, in_axes = [0, None])
weights2 = jax.numpy.stack([weights, weights], axis = 0)
outputs = auto_batch_convolve_v3(inputs, weights2)
print("outputs = ", outputs)
与所有JAX的转换包装函数一样,jax.jit和jax.vmap可以组合使用,这意味着可以使用jit包装的vmap函数,也可以输用vmap包装的jit函数,都可以正常工作。
JAX中高阶导数的处理
计算梯度是现代机器学习方法的重要组成部分。本节讨论一些与现在机器学习相关的自动求导领域的高级功能。虽然在大多数情况下,了解自动求导是如何工作的并不是使用JACX的先决条件,但了解其具体公式可以更加深入地了解JAX的内部运行规律。
由于计算导数的函数本身是可微的,所以JAX的自动求导使计算高阶导数变得容易。因此,高阶导数就像叠加变换一样容易。
一阶导函数
二阶导函数
下面用代码演示如何求一阶导函数和二阶导函数,
import jax
def function(x):
# f(x) = 3x - 2x + x - 1
return 3 * x ** 3 - 2 * x ** 2 + x - 1
def test():
# f'(x) = 6x - 4x + 1
dfx_dx = jax.grad(function)
print("First order derivative = ", dfx_dx(2.))
# f''(x) = 12x- 4
dfx_dx2 = jax.grad(dfx_dx)
print("Second order derivative = ", dfx_dx2(2.))
if __name__ == "__main__":
test()
运行结果打印输出如下,
First order derivative = 29.0
Second order derivative = 32.0
上面代码里test()函数中第一段是求一阶导函数,第二段是求二阶导函数,其中二阶导函数的求解是在一阶导函数之上再嵌套一层求导函数。
除了使用jax.grad进行函数求导之外,还可以使用JAX提供的两个函数来恶完成求导工作,jax.jacfwd、jax.jacrev以及jax.jacobian,他们使用方法于jax.grad类似。
结论
本章通过对jax.vmap的讨论,学习了通过自动化向量函数jax.vmap对函数进行生成自动化打包器,通过使用jax.vmap可以避免使用手动打包对数据进行堆叠打包处理。另外,讨论了JAX中高阶导数的求解。