过往的章节,一直在使用各种JAX包(package,类似于某些编程语言的类库),比如jax.numpy、jax.random、jax.nn、jax.lax等,在各种代买实战中使用了大量的原生API,也就是说,直接使用这些JAX官方提供API建立各种深度学习模型。
对于这些包尚未专门介绍,现开辟一章专门介绍JAX中这些包。
JAX包种类
为了更好地管理多个模块源文件,Python提供了包的概念。通过查看Python目录,比如MacOS里,路径$HOME/Library/Python/{version}/lib/python/site-packages下就是安装的各种包。比如和jax相关的包,
jax
jax-{version}.dist-info
jaxlib
jaxlib-{version}.dist-info
从物理上看,每个包对应一个文件夹,文件夹第一个文件一般是init.py,打开文件夹,可见文件夹包含多个子文件夹,可用于包含多个模块源文件;
_src
example_libraries
experimental
image
interpreters
lax
lib
nn
numpy
ops
scipy
tools
从逻辑上看,包的本质依然是模块。JAX中可直接使用的包分为下面几大类,
| 包名 | 用途 |
|---|---|
| example_libraries | 使用JAX创建libraries的例子,都是一些短小示例代码。一般不会merge的PR。 |
| experimental | 实验性质库,一旦成熟会合并到模块。 |
| image | 图像操作功能。 |
| experimental | 实验性质库,一旦成熟会合并到模块。 |
| image | 图像操作功能。 |
| lax | lax是一个原语操作库用来增强numpy,像诸如JVP和批量规则,通常就是定义层转换的lax原语。很多原语是对相应XLA操作的浅层封装。尽可能低使用想jax.numpy类库去运算,而避免直接使用jax.lax。 |
| lib | 用于桥接JAX Python前端和XLA后端的工具集。 |
| nn | 神经网络通用函数库。 |
| numpy | 使用jax.lax原语实现的NumPy API库,接近NumPy当不是所有。 |
| ops | 函数操作运算符。 |
| scipy | 统计分析。 |
| tools | 工具类。 |
| random | 伪随机数生成工具类。 |
| tree_util | 树形结构功能函数集。 |
| profiler | 性能分析工具,追踪和衡量时间消耗的工具。 |
| dlpack | 深度学习tensorDeviceArray相关。 |
| flatten_util | Pytree维度操作有关,比如降维。 |
可以说,JAX库多种多样,涵盖了深度学习框架的常用函数。下面举例说明。
jax.numpy数值运算
JAX初衷就是为了取代NumPy成为数值运算通用库,当相对于NumPy还有一些区别,
JAX数组不可变(immutable),所以不能像在NumPy随意改变数组元素,这是为配合jax.jit等XLA包装。当然,JAx也提供变通的方法来“更改”数组元素。比如,JAX提供了一个替代的纯索引函数来“更新”数组,而不是array[i] = x。当然,这种所谓更新其实会创建一个新的数组,原数组保持不变。因此,一些在NumPy里返回数组视图的函数,比如numpy.transpose()和numpy.regpe(),在jax.numpy对应的函数将返回数组副本,而不是原数组,尽管在使用jax.jit()编译操作数组时,这些副本通常可以由XLA优化。
jax.numpy提供了大量函数用于数值运算,下图摘抄了官方文档的一部分,

更多详情可参阅https://jax.readthedocs.io/en/latest/jax.numpy.html
很多Python开发者,比较熟悉NumPy的API,关于数组的差异可以简单介绍一下两者对应的操作,简单来说,由于jax.numy的数组一旦定义赋值无法更改,所以引入array.at[i].set(value)的方法,具体数组处理方式如下表,
| JAX | NumPy |
|---|---|
| array.at[index].set(value) | array[index] = value |
| array.at[index].add(value) | array[index] += value |
| array.at[index].multiply(value) | array[index] *= value |
| array.at[index].divide(value) | array[index] /= value |
| array.at[index].power(value) | array[index] **= value |
| array.at[index].min(value) | array[index] = minimum(array[index], value) |
| array.at[index].max(value) | array[index] = maximum(array[index], value) |
| array.at[index].get() | value = array[index] |
如前面所述,array.at[]表达式不会改变原数组,相反,会返回一个修改过的array副本,也就是说一个在内存中新开辟存储单元的数组。但是,在jax.jit()编译函数时,如array = array.at[index].set(value)这样的表达式肯定会被广泛使用。
与NumPy就地操作不同,如array[index] += value不同,如果多个索引引用同一个位置,则将应用所有更新,而NumPy只会应用最后一个更新,而不是所有更新。jax.numpy应用所有更新的次序是根据设定的规则使用,或者根据分布式平台并发行性进行处理。
下面举简单例子,说明jax.numpy数组特性,无数组越界检查
import jax
def test():
array = jax.numpy.arange(10)
print("array =", array)
print("array[15] = ", array[15])
def main():
test()
if __name__ == "__main__":
main()
运行结果打印输出如下,
array = [0 1 2 3 4 5 6 7 8 9]
array[15] = 9
从结果可见,jax.numy没有越界检查,长度为10的数组,当要打印索引为15的元素时,会打印出数组中最后一个数。这是因为jax.numpy支持为超出范围的索引访问提供一种新的模式对数组进行处理。当然,这同样也会带来一些问题,比如把数组越界的问题隐藏,进而给数组的计算会给使用者带来的意料之外的结果,请一定注意。
下面举简单例子,说明jax.numpy数组特性, 数组更新
import jax
def test():
array = jax.numpy.arange(10)
print("array =", array)
print("array[5] = ", array[5])
array_new = array.at[5].set(20)
print("array =", array)
print("array[5] = ", array[5])
print("array_new =", array_new)
print("array_new[5] = ", array_new[5])
def main():
test()
if __name__ == "__main__":
main()
运行结果打印输出如下,
array = [0 1 2 3 4 5 6 7 8 9]
array[5] = 5
array = [0 1 2 3 4 5 6 7 8 9]
array[5] = 5
array_new = [ 0 1 2 3 4 20 6 7 8 9]
array_new[5] = 20
可见,使用 array_new = array.at[5].set(20)并不会改变array数组本身,当会创建一个包含更新后的数组副本array_new。
jax.nn神经网络库
jax.nn提供了常用的实现神经网络模型的函数,比如损失函数,激活函数,独热编码函数等等。

有了这些函数,在设计神经网络模型时,不再需要自定义。由于前面多次使用了这些函数,不再赘述。
jax.experimental实验包和jax.example_libraries示例包
这两个包属于实验性质,新的功能,新的API往往在这两个包里 实验,一个提供实验性API,一个提供示例。由于其实验性,随着版本更新,两个包里说提供的功能也会不同。
下面主要介绍jax.experimental.sparse, jax.experimental.optimizers和jax.experimental.stax
jax.experimental.sparse稀松数据处理
jax.experimental.sparse木块的作用是对戏送花数据进行处理,其主要使用BCOO(Batched-coordinate Sparse Array)来进行,并提供了与JAX函数兼容的压缩存储格式,下面举例说明。
import jax
from jax.experimental import sparse
def sparse():
array = jax.numpy.array([
[0., 1., 0., 2.],
[3., 0., 0., 0],
[0., 0., 4., 0.],
[0., 0., 0., 5.],
])
print("array = ", array)
sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
print("sparsedArray = ", sparsedArray)
def main():
sparse()
if __name__ == "__main__":
main()
将稀松化的数据转换成普通的矩阵,代码如下,
import jax
from jax.experimental import sparse
def sparse():
array = jax.numpy.array([
[0., 1., 0., 2.],
[3., 0., 0., 0],
[0., 0., 4., 0.],
])
print("array = ", array)
sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
print("--------------------------------------")
denseArray = sparsedArray.todense()
print("denseArray = ", denseArray)
print("--------------------------------------")
print("sparsedArray = ", sparsedArray)
print("sparsedArray.data = ", sparsedArray.data)
print("--------------------------------------")
print("sparsedArray.indices = ", sparsedArray.indices)
print("--------------------------------------")
for tuple in sparsedArray.indices:
print(f"array[{tuple[0]}, {tuple[1]}] = ", array[tuple[0], tuple[1]])
print("sparsedArray.ndim = ", sparsedArray.ndim)
print("sparsedArray.shape = ", sparsedArray.shape)
print("sparsedArray.dtype = ", sparsedArray.dtype)
print("sparsedArray.nse = ", sparsedArray.nse)
def main():
sparse()
if __name__ == "__main__":
main()
运行结果打印输出如下,
array = [[0. 1. 0. 2.]
[3. 0. 0. 0.]
[0. 0. 4. 0.]]
--------------------------------------
denseArray = [[0. 1. 0. 2.]
[3. 0. 0. 0.]
[0. 0. 4. 0.]]
--------------------------------------
sparsedArray = BCOO(float32[3, 4], nse=4)
sparsedArray.data = [1. 2. 3. 4.]
--------------------------------------
sparsedArray.indices = [[0 1]
[0 3]
[1 0]
[2 2]]
--------------------------------------
array[0, 1] = 1.0
array[0, 3] = 2.0
array[1, 0] = 3.0
array[2, 2] = 4.0
sparsedArray.ndim = 2
sparsedArray.shape = (3, 4)
sparsedArray.dtype = float32
sparsedArray.nse = 4
BCOO格式是标准稀松格式的一种稍作修改的版本,在数据和索引属性中可以看到原始矩阵的表示形式。
- sparsedArray.data是原始矩阵中所有出现的数值,并以由低到高的顺序排列。
- sparsedArray.indices是原始矩阵中不为0数值的位置。
- sparsedArray.ndim原是数组维度个数。
- sparsedArray.dtype原始矩阵数据类型。
- sparsedArray.nse原始矩阵不为0的元素个数。
此外,BCOO对象还实现了许多类似于数组的方法,允许在JAX程序中直接使用。下面代码演示如何转置矩阵向量积,
import jax
from jax.experimental import sparse
def sparse():
array = jax.numpy.array([
[0., 1., 0., 2.],
[3., 0., 0., 0],
[0., 0., 4., 0.],
])
print("array = ", array)
sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
print("--------------------------------------")
dot(sparsedArray)
def dot(sparsedArray):
array = jax.numpy.array([
[1.],
[2.],
[3.],
])
print("sparsedArray.T = ", sparsedArray)
print("-----------------------------------------")
dotted = sparsedArray.T @ array
print("dotted = ", dotted)
print("-----------------------------------------")
# Normal matrixs required when computing by jax.numpy
dotted = jax.numpy.dot(sparsedArray.T.todense(), array)
print("dotted = ", dotted)
print("-----------------------------------------")
def main():
sparse()
if __name__ == "__main__":
main()
运行结果打印输出如下,
array = [[0. 1. 0. 2.]
[3. 0. 0. 0.]
[0. 0. 4. 0.]]
--------------------------------------
sparsedArray.T = BCOO(float32[3, 4], nse=4)
-----------------------------------------
dotted = [[ 6.]
[ 1.]
[12.]
[ 2.]]
-----------------------------------------
dotted = [[ 6.]
[ 1.]
[12.]
[ 2.]]
-----------------------------------------
至于之前使用过的jax.jit()、jax.vmap()、jax.grad()等函数,西松矩阵也可以直接计算,代码如下,
import jax
from jax.experimental import sparse
def sparse():
array = jax.numpy.array([
[0., 1., 0., 2.],
[3., 0., 0., 0],
[0., 0., 4., 0.],
])
print("array = ", array)
sparsedArray = jax.experimental.sparse.BCOO.fromdense(array)
print("--------------------------------------")
return sparsedArray
@jax.jit
def function(array, sparsedArray):
dotted = sparsedArray.T @ array
dotted = dotted.sum()
return dotted
def main():
array = jax.numpy.array([
[1.],
[2.],
[3.],
])
sparsedArray = sparse()
dotted = function(array, sparsedArray)
print("dotted = ", dotted)
print("---------------------------------")
function_sparsify = jax.experimental.sparse.sparsify(function)
dotted = function_sparsify(array, sparsedArray)
print("dotted = ", dotted)
print("---------------------------------")
if __name__ == "__main__":
main()
运行结果打印输出如下,
array = [[0. 1. 0. 2.]
[3. 0. 0. 0.]
[0. 0. 4. 0.]]
--------------------------------------
dotted = 21.0
---------------------------------
dotted = 21.0
虽然大多数条件下,jax.numpy函数都能够使用sparse.sparsify进行包装处理,例如dot、transpose、add、mul、abs、neg、reduce_sum以及条件语句。
结论
本章介绍了JAX的库包,简单介绍了库包分类,重点介绍了jax.numy和jax.experimental下的稀松矩阵处理。尤其是稀松矩阵处理,在依赖巨大数据量的深度学习中,节约存储空间是重要的性能优化举措。掌握稀松矩阵处理方法,有利于开发高性能的深度学习应用程序。下一章将通过代码深入了解稀松矩阵处理以及学习jax.example_libraries.optimizers优化器。