第47章 JAX库包

过往的章节,一直在使用各种JAX包(package,类似于某些编程语言的类库),比如jax.numpy、jax.random、jax.nn、jax.lax等,在各种代买实战中使用了大量的原生API,也就是说,直接使用这些JAX官方提供API建立各种深度学习模型。

对于这些包尚未专门介绍,现开辟一章专门介绍JAX中这些包。

JAX包种类

为了更好地管理多个模块源文件,Python提供了包的概念。通过查看Python目录,比如MacOS里,路径$HOME/Library/Python/{version}/lib/python/site-packages下就是安装的各种包。比如和jax相关的包,

jax
jax-{version}.dist-info
jaxlib
jaxlib-{version}.dist-info

从物理上看,每个包对应一个文件夹,文件夹第一个文件一般是init.py,打开文件夹,可见文件夹包含多个子文件夹,可用于包含多个模块源文件;

_src
example_libraries
experimental
image
interpreters
lax
lib
nn
numpy
ops
scipy
tools

从逻辑上看,包的本质依然是模块。JAX中可直接使用的包分为下面几大类,

包名 用途
example_libraries 使用JAX创建libraries的例子,都是一些短小示例代码。一般不会merge的PR。
experimental 实验性质库,一旦成熟会合并到模块。
image 图像操作功能。
experimental 实验性质库,一旦成熟会合并到模块。
image 图像操作功能。
lax lax是一个原语操作库用来增强numpy,像诸如JVP和批量规则,通常就是定义层转换的lax原语。很多原语是对相应XLA操作的浅层封装。尽可能低使用想jax.numpy类库去运算,而避免直接使用jax.lax。
lib 用于桥接JAX Python前端和XLA后端的工具集。
nn 神经网络通用函数库。
numpy 使用jax.lax原语实现的NumPy API库,接近NumPy当不是所有。
ops 函数操作运算符。
scipy 统计分析。
tools 工具类。
random 伪随机数生成工具类。
tree_util 树形结构功能函数集。
profiler 性能分析工具,追踪和衡量时间消耗的工具。
dlpack 深度学习tensorDeviceArray相关。
flatten_util Pytree维度操作有关,比如降维。

可以说,JAX库多种多样,涵盖了深度学习框架的常用函数。下面举例说明。

jax.numpy数值运算

JAX初衷就是为了取代NumPy成为数值运算通用库,当相对于NumPy还有一些区别,

JAX数组不可变(immutable),所以不能像在NumPy随意改变数组元素,这是为配合jax.jit等XLA包装。当然,JAx也提供变通的方法来“更改”数组元素。比如,JAX提供了一个替代的纯索引函数来“更新”数组,而不是array[i] = x。当然,这种所谓更新其实会创建一个新的数组,原数组保持不变。因此,一些在NumPy里返回数组视图的函数,比如numpy.transpose()和numpy.regpe(),在jax.numpy对应的函数将返回数组副本,而不是原数组,尽管在使用jax.jit()编译操作数组时,这些副本通常可以由XLA优化。

jax.numpy提供了大量函数用于数值运算,下图摘抄了官方文档的一部分,


图1 jax.numpy数值运算函数

更多详情可参阅https://jax.readthedocs.io/en/latest/jax.numpy.html

很多Python开发者,比较熟悉NumPy的API,关于数组的差异可以简单介绍一下两者对应的操作,简单来说,由于jax.numy的数组一旦定义赋值无法更改,所以引入array.at[i].set(value)的方法,具体数组处理方式如下表,

JAX NumPy
array.at[index].set(value) array[index] = value
array.at[index].add(value) array[index] += value
array.at[index].multiply(value) array[index] *= value
array.at[index].divide(value) array[index] /= value
array.at[index].power(value) array[index] **= value
array.at[index].min(value) array[index] = minimum(array[index], value)
array.at[index].max(value) array[index] = maximum(array[index], value)
array.at[index].get() value = array[index]

如前面所述,array.at[]表达式不会改变原数组,相反,会返回一个修改过的array副本,也就是说一个在内存中新开辟存储单元的数组。但是,在jax.jit()编译函数时,如array = array.at[index].set(value)这样的表达式肯定会被广泛使用。

与NumPy就地操作不同,如array[index] += value不同,如果多个索引引用同一个位置,则将应用所有更新,而NumPy只会应用最后一个更新,而不是所有更新。jax.numpy应用所有更新的次序是根据设定的规则使用,或者根据分布式平台并发行性进行处理。

下面举简单例子,说明jax.numpy数组特性,无数组越界检查


import jax

def test():

    array = jax.numpy.arange(10)
    
    print("array =", array)
    print("array[15] = ", array[15])
    
def main():

    test()
    
if __name__ == "__main__":

    main()

运行结果打印输出如下,


array = [0 1 2 3 4 5 6 7 8 9]
array[15] =  9

从结果可见,jax.numy没有越界检查,长度为10的数组,当要打印索引为15的元素时,会打印出数组中最后一个数。这是因为jax.numpy支持为超出范围的索引访问提供一种新的模式对数组进行处理。当然,这同样也会带来一些问题,比如把数组越界的问题隐藏,进而给数组的计算会给使用者带来的意料之外的结果,请一定注意。

下面举简单例子,说明jax.numpy数组特性, 数组更新


import jax

def test():

    array = jax.numpy.arange(10)
    
    print("array =", array)
    print("array[5] = ", array[5])
    
    array_new = array.at[5].set(20)
    
    print("array =", array)
    print("array[5] = ", array[5])
    
    print("array_new =", array_new)
    print("array_new[5] = ", array_new[5])
    
def main():

    test()
    
if __name__ == "__main__":

    main()

运行结果打印输出如下,


array = [0 1 2 3 4 5 6 7 8 9]
array[5] =  5
array = [0 1 2 3 4 5 6 7 8 9]
array[5] =  5
array_new = [ 0  1  2  3  4 20  6  7  8  9]
array_new[5] =  20

可见,使用 array_new = array.at[5].set(20)并不会改变array数组本身,当会创建一个包含更新后的数组副本array_new。

jax.nn神经网络库

jax.nn提供了常用的实现神经网络模型的函数,比如损失函数,激活函数,独热编码函数等等。


图2 jax.nn神经网络组件

有了这些函数,在设计神经网络模型时,不再需要自定义。由于前面多次使用了这些函数,不再赘述。

jax.experimental实验包和jax.example_libraries示例包

这两个包属于实验性质,新的功能,新的API往往在这两个包里 实验,一个提供实验性API,一个提供示例。由于其实验性,随着版本更新,两个包里说提供的功能也会不同。

下面主要介绍jax.experimental.sparse, jax.experimental.optimizers和jax.experimental.stax

jax.experimental.sparse稀松数据处理

jax.experimental.sparse木块的作用是对戏送花数据进行处理,其主要使用BCOO(Batched-coordinate Sparse Array)来进行,并提供了与JAX函数兼容的压缩存储格式,下面举例说明。


import jax
from jax.experimental import sparse

def sparse():

    array = jax.numpy.array([
        [0., 1., 0., 2.],
        [3., 0., 0., 0],
        [0., 0., 4., 0.],
        [0., 0., 0., 5.],
    ])
    
    print("array = ", array)
    
    sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
    
    print("sparsedArray = ", sparsedArray)
    
def main():

    sparse()
    
if __name__ == "__main__":

    main()

将稀松化的数据转换成普通的矩阵,代码如下,


import jax
from jax.experimental import sparse

def sparse():

    array = jax.numpy.array([
        [0., 1., 0., 2.],
        [3., 0., 0., 0],
        [0., 0., 4., 0.],
    ])
    
    print("array = ", array)
    sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
    print("--------------------------------------")
    
    denseArray = sparsedArray.todense()
    print("denseArray = ", denseArray)
    print("--------------------------------------")
    
    print("sparsedArray = ", sparsedArray)
    print("sparsedArray.data = ", sparsedArray.data)
    print("--------------------------------------")
    
    print("sparsedArray.indices = ", sparsedArray.indices)
    print("--------------------------------------")
        
    for tuple in sparsedArray.indices:
    
        print(f"array[{tuple[0]}, {tuple[1]}] = ", array[tuple[0], tuple[1]])
        
    print("sparsedArray.ndim = ", sparsedArray.ndim)
    print("sparsedArray.shape = ", sparsedArray.shape)
    print("sparsedArray.dtype = ", sparsedArray.dtype)
    print("sparsedArray.nse = ", sparsedArray.nse)

    
def main():

    sparse()
    
if __name__ == "__main__":

    main()

运行结果打印输出如下,


array =  [[0. 1. 0. 2.]
 [3. 0. 0. 0.]
 [0. 0. 4. 0.]]
--------------------------------------
denseArray =  [[0. 1. 0. 2.]
 [3. 0. 0. 0.]
 [0. 0. 4. 0.]]
--------------------------------------
sparsedArray =  BCOO(float32[3, 4], nse=4)
sparsedArray.data =  [1. 2. 3. 4.]
--------------------------------------
sparsedArray.indices =  [[0 1]
 [0 3]
 [1 0]
 [2 2]]
--------------------------------------
array[0, 1] =  1.0
array[0, 3] =  2.0
array[1, 0] =  3.0
array[2, 2] =  4.0
sparsedArray.ndim =  2
sparsedArray.shape =  (3, 4)
sparsedArray.dtype =  float32
sparsedArray.nse =  4

BCOO格式是标准稀松格式的一种稍作修改的版本,在数据和索引属性中可以看到原始矩阵的表示形式。

  • sparsedArray.data是原始矩阵中所有出现的数值,并以由低到高的顺序排列。
  • sparsedArray.indices是原始矩阵中不为0数值的位置。
  • sparsedArray.ndim原是数组维度个数。
  • sparsedArray.dtype原始矩阵数据类型。
  • sparsedArray.nse原始矩阵不为0的元素个数。

此外,BCOO对象还实现了许多类似于数组的方法,允许在JAX程序中直接使用。下面代码演示如何转置矩阵向量积,


import jax
from jax.experimental import sparse

def sparse():

    array = jax.numpy.array([
        [0., 1., 0., 2.],
        [3., 0., 0., 0],
        [0., 0., 4., 0.],
    ])
    
    print("array = ", array)
    sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
    print("--------------------------------------")
    
    dot(sparsedArray)
    
def dot(sparsedArray):

    array = jax.numpy.array([
        [1.],
        [2.],
        [3.],
    ])
    
    print("sparsedArray.T = ", sparsedArray)
    print("-----------------------------------------")
    
    dotted = sparsedArray.T @ array
    print("dotted = ", dotted)
    print("-----------------------------------------")
    
    # Normal matrixs required when computing by jax.numpy
    dotted = jax.numpy.dot(sparsedArray.T.todense(), array)
    print("dotted = ", dotted)
    print("-----------------------------------------")
    
    
def main():

    sparse()
    
if __name__ == "__main__":

    main()

运行结果打印输出如下,


array =  [[0. 1. 0. 2.]
 [3. 0. 0. 0.]
 [0. 0. 4. 0.]]
--------------------------------------
sparsedArray.T =  BCOO(float32[3, 4], nse=4)
-----------------------------------------
dotted =  [[ 6.]
 [ 1.]
 [12.]
 [ 2.]]
-----------------------------------------
dotted =  [[ 6.]
 [ 1.]
 [12.]
 [ 2.]]
-----------------------------------------

至于之前使用过的jax.jit()、jax.vmap()、jax.grad()等函数,西松矩阵也可以直接计算,代码如下,


import jax
from jax.experimental import sparse

def sparse():

    array = jax.numpy.array([
        [0., 1., 0., 2.],
        [3., 0., 0., 0],
        [0., 0., 4., 0.],
    ])
    
    print("array = ", array)
    sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
    print("--------------------------------------")
    
    return sparsedArray
 
@jax.jit
def function(array, sparsedArray):
    
    dotted = sparsedArray.T @ array
    dotted = dotted.sum()
    
    return dotted
    
def main():

    array = jax.numpy.array([
    
        [1.],
        [2.],
        [3.],
    ])
    
    sparsedArray = sparse()
    
    dotted = function(array, sparsedArray)
    
    print("dotted = ", dotted)
    print("---------------------------------")
    
    function_sparsify = jax.experimental.sparse.sparsify(function)
    dotted = function_sparsify(array, sparsedArray)
    
    print("dotted = ", dotted)
    print("---------------------------------")
    
    
if __name__ == "__main__":

    main()

运行结果打印输出如下,


array =  [[0. 1. 0. 2.]
 [3. 0. 0. 0.]
 [0. 0. 4. 0.]]
--------------------------------------
dotted =  21.0
---------------------------------
dotted =  21.0

虽然大多数条件下,jax.numpy函数都能够使用sparse.sparsify进行包装处理,例如dot、transpose、add、mul、abs、neg、reduce_sum以及条件语句。

结论

本章介绍了JAX的库包,简单介绍了库包分类,重点介绍了jax.numy和jax.experimental下的稀松矩阵处理。尤其是稀松矩阵处理,在依赖巨大数据量的深度学习中,节约存储空间是重要的性能优化举措。掌握稀松矩阵处理方法,有利于开发高性能的深度学习应用程序。下一章将通过代码深入了解稀松矩阵处理以及学习jax.example_libraries.optimizers优化器。

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容