第22章 XLA和JAX特性简介

JAX简单来说就是支持GPU计算(比如CUDA)、自动求导(Autograd)的NumPy。在Python,进行数值计算离不开NumPy这个基础数值运算库。但NumPy不支持GPU或者其他硬件加速器,也不对Backwards Propagation的内置支持,再加上Python本身的速度限制,很少会在生产环境下直接使用NumPy训练或部署深度学习模型。这也是为什么会出现PyTorch、TensorFlow、Caffe等深度学习框架的原因。

JAX是机器学习框架领域的新生力量,其建立在加速线性代数XLA(Accelerated Linear Algebra)。具有更快的高阶求导,对GPU和TPU更好的支持。

图1 JAX

通过XLA,JAX能对原生Python和NumPy进行自动微分。

XLA与JAX

JAX使用XLA在GPU和TPU上进行编译和运行NumPy程序,XLA是将JAX转化为加速器支持操作的中坚力量。

XLA作为深度学习编译器,其长期以来作为Google在深度学习领域的一个重要特性被开发,特别是hi作为TensorFlow 2.0背后支持之一,XLA也从实验特性变成了默认打开的特性。

XLA如何运行

XLA是一种针对特定领域的线性代数编译器,比如能够加快TensorFlow模型的运行速度,而且完全不需要更改源码。一般编译器都服务于如下目的,

  • 提高代码执行速度。
  • 优化存储使用。

XLA也不例外。XLA功能主要体现在以下几个方面,

  • 融合可组合算子从而提高运算速度。

XLA通过对TensorFlow运行时计算图进行分析,将多个低级算子进行融合,从而生成高效的机器码。如下图,计算图中许多算子都是逐元素(element-wise)地计算,所以融合(Fuse)到一个element-wise的循环计算kernel中,从而减少了kernel加载,提高计算性能。

图2 算子及计算图

例如,上图中,bias向量里某单个元素与从matmul出来的结果某单个元素进行加法计算,对于其后的ReLU函数,这个加法计算结果是一个可与0进行比较的单元素。比较的结果使用输入值的幂进行指数运算和除法运算,从而产生softmax的结果。

将多个kernel融合为一个kernel,因此减少了加载kernel的时间消耗。

因此,XLA对模型的性能提升主要是将多个连续的element-wise算子融合为一个算子。

  • 减少内存占用

在上面elemnt-wise计算整个过程,不需要为matmul、add、和ReLU等算子使用中间数组而开辟内存空间。通过将这些算子融合,可以减少申请这些酸子间的中间结果所占用的内存。

  • 减少模型可执行文件大小

对移动设备场景,XLA可以减少模型的可执行文件的大小。通过AOT(Ahead of Time)编译将整个计算图生成轻量级的机器码,这些机器码实现了计算图中的各个操作。

在模型运行时,不需要在一个完整的运行环境,因为计算图实际执行时的操作被转换编译为设备代码。

  • 方便支持不同硬件后端

当传统的深度学习应用支持一种新的设备时(比如不同的xPU指令集),需要将所有的kernels(OPs)再重新实现一遍,需要巨大的工作量。而通过XLA则需要很小的工作量,因为XLA的的算子都是操作原语(低级算子,primitive or atomic action,不可再分、不可中断的操作),数量少且容易实现,XLA会自动地将复杂的算子拆解为primitive算子。

可以说,XLA在TensorFlow 2.0上成功的实践,移植到到JAX,让JAX站在了成功的肩膀上。

XLA如何工作

XLA的输入语言是HLO IR(High Level Optimizer Intermediate Representation)。XLA将HLO描述的计算图(计算流程)编译成真对各种特定后端的机器指令。

整个流程如下图,

  • 首先XLA对输入的HLO计算图进行目标设备无关的优化和分析,如CSE(Common Subexpression Elimination,公共子表达式删除,通过删除运算中等价子表达式的一种代码优化方法)、算子融合,运行时的内存分配分析。输出为优化后的HLO计算图。

  • 然后将HLO计算图发送到后端(Backend),后端结合特定的硬件属性对HLO计算图进行进一步的HLO级优化,例如对某些操作或其组合进行模式匹配从而优化计算库调用。

  • 最后,后端将HLO IR转化为LLVM IR,LLVM再进行低级优化并生成机器码。

JAX一般特性

JAX可以自动微分原生Python和NumPy函数。可以通过Python的大部分功能(包括循环、if、递归和闭包)进行微分。支持反向模式和正向模式微分,并且两者可以以任意顺序组成。

JAX的新功能是使用XLA再诸如GPU和TPU的加速器上编译和运行NumPy代码。默认情况下,编译器是后台运行的,而库调用将得到及时的编译和运行。但JAX允许使用单功能API将Python函数编译为XLA优化的内核。编译和自动微分可以任意组合,因此无需离开Python即可表达复杂的算法并获得最佳性能。

利用JIT加速程序运行

虽然精心编写的NumPy代码运行起来效率很高,但对于现代机器学习来说,还是希望这些代码运行的尽可能快。这一般通过在GPU或TPU(Tensor Processing Unit,张量处理单元)等不同的的“加速器”上运行代码来实现。

JAX提供了JIT(Just In Time即时)编译器,采用标准的Python、NumPy函数,经编译后可以在加速器上高效运行。编译函数还可以避免Python解释器的开销。总的来说,jax.jit可以显著加速代码运行,且基本上没有额外的编码开销,需要做的就是使用JAX编译函数。在使用jax.jit时,即使是微小的神经网络也可以实现相当惊人的加速效果。

下面演示使用jit加速的程序。


import time
import jax

def selu(inputs, alpha = 1.67, theta = 1.05):
    
    return jax.numpy.where(inputs > 0, inputs, alpha * jax.numpy.exp(inputs) - alpha) * theta

@jax.jit
def selu1(inputs, alpha = 1.67, theta = 1.05):
    
    return jax.numpy.where(inputs > 0, inputs, alpha * jax.numpy.exp(inputs) - alpha) * theta

def test():
    
    key = jax.random.PRNGKey(15)
    inputs = jax.random.normal(key = key, shape = (1000000,))
    
    start = time.time()
    
    selu(inputs = inputs)
    
    end = time.time()
    
    print(f"Time consumed by selu: %.2f seconds" % (end - start))
    
    jit_selu = jax.jit(selu)
    
    start = time.time()
    
    jit_selu(inputs)
    
    end = time.time()
    
    print(f"Time consumed by jit_selu: %.2f seconds" % (end - start))
    
    start = time.time()
    
    selu1(inputs)
    
    end = time.time()
    
    print("Time consumed by selu1: %.2f seconds" % (end - start))
    
if __name__ == "__main__":
    
    test()

第一个selu时Python的原生函数,第二个selu时通过jax.jit包装(wrap)后的函数,第三个是通过@jax.jit修饰的selu1函数。运行结果打印输出如下,


Time consumed by selu: 0.43 seconds
Time consumed by jit_selu: 0.10 seconds
Time consumed by selu1: 0.10 seconds

可见,相同的一段代码在不同的运行机制下速度有极大的不同。由于JAX充分利用了jit的特性,通过jit包装或修饰,使得函数在第一次调用后就被jit compile缓存。

自动求导器

除了评估数值函数,还可以使用自动求导。在JAX中,使用jax.grad函数来计算导数。jax.grad接受一个函数并返回一个新函数,该函数计算原始函数的梯度。要使用梯度下降,可以根据神经网络的参数计算损失函数的梯度,即使用jax.grad(loss_function)来计算损失函数的梯度。

下面演示对一个计算函数使用jax.grad进行求导。


import jax

def cumulate_exponents(inputs):

    # Sigmoid Activation Funciton
    # f(x) = 1.0/(1.0 + e⁻ˣ)
    sigmoids = 1. / (1. + jax.numpy.exp(-inputs))
    #
    # Σsigmoids
    return jax.numpy.sum(sigmoids)

def test():

    # f'(x) = e⁻ˣ / (1 + e⁻ˣ)² = (1 - f(x)) x f(x)
    # derivative_cumulate_exponents = f'(x)
    # [0.25, ]
    derivative_cumulate_exponents = jax.grad(cumulate_exponents)

    # inputs = [0. 1. 2.]
    # sigmoids = [0.5, 0.731058578630005, 0.880797077977882]
    inputs = jax.numpy.arange(3.)

    # Σsigmoids = Σ[0.5, 0.731058578630005, 0.880797077977882] = 2.1118555
    outputs = cumulate_exponents(inputs)
    derivative_outputs = derivative_cumulate_exponents(inputs)

    print("inputs =", inputs, ", outputs =", outputs, ", derivative_outputs =", derivative_outputs)

if __name__ == "__main__":

    test()

首先生成一个数值序列,对数值公式(其实是sigmoid激活函数)求和,事后使用自动求导器进行求导。注意,原函数求和后输出一个值,经过了求导,按照权重,又分配给3个数值。
运行结果打印输出如下,


inputs = [0. 1. 2.] , outputs = 2.1118555 , derivative_outputs = [0.25       0.19661197 0.10499357]

注意,按照jax的要求,jax.numpy.arange需要输入浮点型。

下面演示共同使用jax.jit函数和jax.grad函数的方法。


import jax
import time

def cumulate_exponents1(inputs):
    
    return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))

def cumulate_exponents2(inputs):
    
    return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))

def cumulate_exponents3(inputs):
    
    return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))

@jax.jit
def cumulate_exponents_jit(inputs):
    
    return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))

def test():
    
    inputs = jax.numpy.arange(1024.)
    
    start = time.time()
    
    cumulate_exponents_jit_derivative = jax.grad(cumulate_exponents_jit)
    
    print("cumulate_exponents_jit_derivative(inputs) = ", cumulate_exponents_jit_derivative(inputs))
    
    end = time.time()
    
    print("Time consumed: %.2f" % (end - start))
    print("------------------------------")
    
    start = time.time()
    
    cumulate_exponents1_derivative = jax.grad(cumulate_exponents1)
    
    print("cumulate_exponents1_derivative(inputs) = ", cumulate_exponents1_derivative(inputs))

    end = time.time()
    
    print("Time consumed: %.2f" % (end - start))
    print("------------------------------")
    
    start = time.time()
    
    cumulate_exponents2_derivative = jax.grad(cumulate_exponents2)
    cumulate_exponents2_derivative_jit = jax.jit(cumulate_exponents2_derivative)
    
    print("cumulate_exponents2_derivative_jit(inputs) = ", cumulate_exponents2_derivative_jit(inputs))

    end = time.time()
    
    print("Time consumed: %.2f" % (end - start))
    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


cumulate_exponents_jit_derivative(inputs) =  [0.25       0.19661197 0.10499357 ... 0.         0.         0.        ]
Time consumed: 0.07
------------------------------
cumulate_exponents1_derivative(inputs) =  [0.25       0.19661197 0.10499357 ... 0.         0.         0.        ]
Time consumed: 0.19
------------------------------
cumulate_exponents2_derivative_jit(inputs) =  [0.25       0.19661197 0.10499357 ... 0.         0.         0.        ]
Time consumed: 0.03

可见,使用jit包装函数极大地提高了运行速度,而且当包装顺序调整后,花费时间只有原始函数的1/6,提升极大。

vmap函数实现自动向量化映射

JAX的API里还有另外一个重要转换函数vmap,即向量化映射。字面意思是创建函数,映射到以参数为轴函数。该函数具有沿着参数轴映射函数的熟悉语义(Familiar Semantics),但是不是将循环保留在外部,而是将循环推入函数的原始操作中以提高性能。当与jax.jit结合时,更能提高计算速度。

在实践中,当训练现代机器学习模型是,可以执行“小批量”的梯度下降,在梯度下降的每个步骤中,对一小批示例中的损失梯度求平均值;当示例的数据适中时,这样做完全没有问题,但当数据过多时,这样做会使得JAX在计算时消耗大量的时间。

解决的办法就是JAX额外提供了jax.vmap,该函数可以对函数进行“向量化”处理,也就是说该函数允许在输入的某个轴上并行计算函数的输出。简单来说,就是可应用jax.vmap函数向量化并立即获得损失函数梯度的版本,该版本适用于小批量示例。

下面用代码说明jax.vmap方法。


import jax
import time

def cumulate_exponents(inputs):
    
    return jax.numpy.sum(1. / (1. + jax.numpy.exp(-inputs)))

def test():
    
    start = time.time()
    inputs = jax.numpy.arange(1024000.)
    
    cumulate_exponents1_derivative = jax.grad(cumulate_exponents)
    cumulate_exponents1_derivative(inputs)
    
    end = time.time()
    
    print("Time consumed: %.2f" % (end - start), "when function executing cumulate_exponents1_derivative")
    
    start = time.time()
    inputs = jax.numpy.arange(1024000.)
    
    cumulate_exponents2_derivative = jax.vmap(jax.grad(cumulate_exponents))
    cumulate_exponents2_derivative(inputs)
    
    end = time.time()
    
    print("Time consumed: %.2f" % (end - start), "when function executing cumulate_exponents2_derivative")

    start = time.time()
    inputs = jax.numpy.arange(1024000.)
    
    cumulate_exponents3_derivative = jax.jit(jax.vmap(jax.grad(cumulate_exponents)))
    
    cumulate_exponents3_derivative(inputs)
    
    end = time.time()
    
    print("Time consumed: %.2f" % (end - start), "when function executing cumulate_exponents3_derivative")

    
if __name__ == "__main__":
    
    test()

运行结果打印输出如下,


Time consumed: 0.29 when function executing cumulate_exponents1_derivative
Time consumed: 0.22 when function executing cumulate_exponents2_derivative
Time consumed: 0.04 when function executing cumulate_exponents3_derivative

可以看到,随着加载更多JAX的特性函数,计算时间依次递减。

除了jax.grad、jax.jit以及jax.vmap,在介绍另外两个函数,

  • in_axes,是一个元组或整数,该函数高速JAX函数参数应该对哪些轴并行化。元组应该与vmap函数的参数数量相同,或者只有一个参数是为整数。比如,使用 (None, 0, 0)是指不在第一个参数(argument)上并行化,并在第二个参数和第三个参数的第一个(索引0)维度上并行化。
  • out_axes,与in_axes类似。指定了函数输出的哪些轴并行化。

结论

本章探讨了JAX和XLA的特性,以及XLA运行机制。另外,通过代码说明jax.grad、jax.jit、jax.vmap的功能和用法。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,012评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,628评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,653评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,485评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,574评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,590评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,596评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,340评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,794评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,102评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,276评论 1 344
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,940评论 5 339
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,583评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,201评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,441评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,173评论 2 366
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,136评论 2 352

推荐阅读更多精彩内容