第32章 JAX函数自动打包器

前第22章我们提到过vmap函数实现自动向量化映射,对数据进行打包和计算,本章将深入介绍vmap的一些细节。

用剥洋葱的方式处理数组

传统的批处理方式很一般都是对数据的批处理,之后再使用函数从外到内一层层地重新计算。首先我们实现如下代码。


import jax

def convolve(inputs, weights):
    
    outputs = []
    stride_size = 2
    
    # i in [0, 1, 2]
    # len(inputs) = 5
    for i in range(len(inputs) - stride_size):
        
        # inputs[0, 2], inputs[1, 3], inputs[2, 4]
        slices = inputs[i: i + stride_size + 1]
        dotted = jax.numpy.dot(slices, weights)
        outputs.append(dotted)
        
    return jax.numpy.array(outputs)

def test():
    
    # [0, 1, 2, 3, 4]
    inputs = jax.numpy.arange(5)
    weights = jax.numpy.array([2., 3., 4.])
    
    outputs = convolve(inputs, weights)
    
    print("outputs = ", outputs)
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


outputs =  [11. 20. 29.]

这是一个非常简单的矩阵乘法运算。假设此时我们需要对其进行修改,将原来的inputs和weights分别进行二次堆叠,即发生如下变化,

inputs = jax.numpy.stack([inputs, inputs])
weights = jax.numpy.stack([weights, weights])

再添加一个新的计算函数def batched_convolve(inputs, weights),代码如下所示,


import jax

def convolve(inputs, weights):
    
    outputs = []
    stride_size = 2
    
    # i in [0, 1, 2]
    # len(inputs) = 5
    for i in range(len(inputs) - stride_size):
        
        # inputs[0: 3], inputs[1: 4], inputs[2, 5]
        slices = inputs[i: i + stride_size + 1]
        dotted = jax.numpy.dot(slices, weights)
        outputs.append(dotted)
        
    return jax.numpy.array(outputs)

def batched_convolve(inputs, weights):
    
    outputs = []
    
    for i in range(inputs.shape[0]):
        
        outputs.append(convolve(inputs[i], weights[i]))
    
    outputs = jax.numpy.stack(outputs)
    
    return outputs

def test():
    
    # [0, 1, 2, 3, 4]
    inputs = jax.numpy.arange(5)
    weights = jax.numpy.array([2., 3., 4.])
    
    # outputs = convolve(inputs, weights)
    
    # print(f"inputs = {inputs}, weights = {weights}, outputs = ", outputs)
    
    inputs = jax.numpy.stack([inputs, inputs])
    weights = jax.numpy.stack([weights, weights])
    
    outputs = batched_convolve(inputs, weights)
    
    print(f"inputs = {inputs}, weights = {weights}, outputs = ", outputs)
    
if __name__ == "__main__":
    
    test()

新的函数调用了原来的函数,并将其作为自身的一部分进行计算和输出。运行结果打印输出如下,


inputs = [[0 1 2 3 4]
 [0 1 2 3 4]], weights = [[2. 3. 4.]
 [2. 3. 4.]], outputs =  [[11. 20. 29.]
 [11. 20. 29.]]

可以看到,这是完成输出的运算结果。结果是正确的,但是从效率方面来说,这个函数并不高。为了有效地对计算进行批次处理,通常需要手动重写函数,以确保函数是以矢量化形式完成的。这并不难实现,但设计更该函数如何处理索引(index)、轴(axes)和输入等内容。手动实现批处理的程序如下所示,


def vectorized_convolve_v1(inputs, weights, stride_size = 2):
        
    outputs = []
        
    for i in range(0, inputs.shape[-1] - stride_size):
            
        slices = inputs[:, i: i + stride_size + 1]
        dotted = slices @ weights.T
        outputs.append(dotted)
            
    jax.numpy.stack(outputs, axis = 1)
        
    return outputs

def test():
    
    # [0, 1, 2, 3, 4]
    inputs = jax.numpy.arange(5)
    weights = jax.numpy.array([2., 3., 4.])
    
    # outputs = convolve(inputs, weights)
    
    # print(f"inputs = {inputs}, weights = {weights}, outputs = ", outputs)
    
    inputs = jax.numpy.stack([inputs, inputs])
    weights = jax.numpy.stack([weights, weights])
    
    '''
    
    outputs = batched_convolve(inputs, weights)
        
    print(f"inputs = {inputs}\n, weights = {weights},\n outputs = ", outputs)
    
    print("------------------------------------------")
    
    outputs = vectorized_convolve(inputs, weights)
        
    print(f"inputs = {inputs}, \nweights = {weights},\n outputs = ", outputs)
    
    print("------------------------------------------")
    
    '''
    
    outputs = vectorized_convolve_v1(inputs, weights)
        
    print(f"inputs = {inputs}, \nweights = {weights},\n outputs = ", outputs)
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


inputs = [[0 1 2 3 4]
 [0 1 2 3 4]], 
weights = [[2. 3. 4.]
 [2. 3. 4.]],
 outputs =  [Array([[11., 11.],
       [11., 11.]], dtype=float32), Array([[20., 20.],
       [20., 20.]], dtype=float32), Array([[29., 29.],
       [29., 29.]], dtype=float32)]

可以看到,无论采用何种方法,都可以完成数据的计算,然而这所有的方法和算法都是基于对数据的打包,仍然是一次次地将数据输入到函数中进行计算。能否改变一下思路,在数据不动的情况下,。打包函数进行计算呢?

自动向量化函数vmap

在JAX中,jax.vmap转换被设计成为在参数轴上自动生成一个函数的自动化打包器。使用上一节中提到的convolve函数并计算其处理多维数据的结果,代码如下所示,


import jax

def convolve(inputs, weights, stride_size = 2):
    
    outputs = []
    
    for i in range(inputs.shape[-1] - stride_size):
        
        slices = inputs[i: i + stride_size + 1]
        dotted = jax.numpy.dot(slices, weights)
        
        outputs.append(dotted)
        
    outputs = jax.numpy.array(outputs)
    
    return outputs

def test():
    
    inputs = jax.numpy.arange(5)
    # inputs = jax.numpy.stack([inputs, inputs])
    
    weights = jax.numpy.array([2., 3., 4.])
    # weights = jax.numpy.stack([weights, weights])
    
    outputs = convolve(inputs, weights)
    
    print("outputs = ", outputs)
    
    auto_batch_convolve = jax.vmap(convolve)
    
    inputs = jax.numpy.stack([inputs, inputs])
    weights = jax.numpy.stack([weights, weights])
    
    outputs = auto_batch_convolve(inputs, weights)
    
    print("outputs = ", outputs)
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


outputs =  [11. 20. 29.]
outputs =  [[11. 20. 29.]
 [11. 20. 29.]]

注意,使用jax.vmap进行包装的代码,jax.vmap通过类似一jax.jit的跟踪函数来实现对函数的自动化打包,并在每个输入的开头自动添加批处理轴,不用再使用jax.numpy.stack(outputs, axis = 1)进行堆叠。

如果批处理维度不是第一个,则可以使用in_axes和out_axes参数来指定批处理维度在输入和输出的位置。如果所有输入和输出的批处理相同,或者列表相同,则为正数。代码如下所示,

    auto_batch_convolve_v2 = jax.vmap(convolve, in_axes = 1, out_axes = 1)
    
    inputs = jax.numpy.transpose(inputs)
    weights = jax.numpy.transpose(weights)
    
    outputs = auto_batch_convolve_v2(inputs, weights)
    
    print("outputs = ", outputs)

运行结果打印输出如下,


outputs =  [11. 20. 29.]
outputs =  [[11. 20. 29.]
 [11. 20. 29.]]
outputs =  [[11. 11.]
 [20. 20.]
 [29. 29.]]

此外,还有一种情况,提提供了两个经过批处理后的数据,但是在某些情况下可能只有一个数据被批处理进修正,此时vmpa同样可以对其操作,代码如下所示,


    auto_batch_convolve_v3 = jax.vmap(convolve, in_axes = [0, None])
    weights2 = jax.numpy.stack([weights, weights], axis = 0)
    
    outputs = auto_batch_convolve_v3(inputs, weights2)
    
    print("outputs = ", outputs)

与所有JAX的转换包装函数一样,jax.jit和jax.vmap可以组合使用,这意味着可以使用jit包装的vmap函数,也可以输用vmap包装的jit函数,都可以正常工作。

JAX中高阶导数的处理

计算梯度是现代机器学习方法的重要组成部分。本节讨论一些与现在机器学习相关的自动求导领域的高级功能。虽然在大多数情况下,了解自动求导是如何工作的并不是使用JACX的先决条件,但了解其具体公式可以更加深入地了解JAX的内部运行规律。

由于计算导数的函数本身是可微的,所以JAX的自动求导使计算高阶导数变得容易。因此,高阶导数就像叠加变换一样容易。
f\left( x \right) = 3x^3 - 2x^2 + x - 1

一阶导函数
\frac{d f\left( x \right)}{dx} = 9x^2 - 4x + 1
二阶导函数
\frac{d^2 f\left( x \right) }{dx} = 18x - 4
下面用代码演示如何求一阶导函数和二阶导函数,


import jax

def function(x):
    
    # f(x) = 3x       - 2x       + x - 1
    return 3 * x ** 3 - 2 * x ** 2 + x - 1

def test():
    
    # f'(x) = 6x      - 4x + 1
    dfx_dx = jax.grad(function)
    print("First order derivative = ", dfx_dx(2.))
    
    # f''(x) = 12x- 4
    dfx_dx2 = jax.grad(dfx_dx)
    print("Second order derivative = ", dfx_dx2(2.))
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


First order derivative =  29.0
Second order derivative =  32.0

上面代码里test()函数中第一段是求一阶导函数,第二段是求二阶导函数,其中二阶导函数的求解是在一阶导函数之上再嵌套一层求导函数。

除了使用jax.grad进行函数求导之外,还可以使用JAX提供的两个函数来恶完成求导工作,jax.jacfwd、jax.jacrev以及jax.jacobian,他们使用方法于jax.grad类似。

结论

本章通过对jax.vmap的讨论,学习了通过自动化向量函数jax.vmap对函数进行生成自动化打包器,通过使用jax.vmap可以避免使用手动打包对数据进行堆叠打包处理。另外,讨论了JAX中高阶导数的求解。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容