第45章 JAX自定义函数的原语规则

本章起,学习JAX创建函数的基本规则。JAX的基本规则称为“原语(Primitives)”,原语一次来自于操作系统,指的是执行过程中不可被打断的基本操作。可以理解为一段代码在执行过程中不能被打断,像原子一样不可分割。

接下来几章学习的目的就是为了了解如何利用和遵循“原语“去自定义操作或函数。

JAX函数基本规则

之前用过的jit、grad、vmap等函数,其编译或运行时,JAX会在内部通过转换的方式,实现一些较为复杂的计算。这些函数都是使用了通用的JAX内部机制,将单一维度的函数转换成所需要的多维度函数。

这些操作有一个非常重的需要求就是对数据的属性进行检查,即要求函数说使用数据必须时刻被追踪(Traceable)的。这是由于JAX的对函数进行转换时并不是对具体的一个参数或者具体的某个值进行处理,而是调用参数对象的抽象值。JAX捕获参数的类型和形状,例如ShapedArray(float32[2, 2]),而不是具体的值。

JAX自身预定义的函数,例如add、matmul、sin和cos等数值计算函数在实现时严格遵循了JAX的编译及运行机制规则。自定义函数在组合使用这些函数完成计算时,同样也可以被包装,从而快捷底完成一些较为复杂的计算。

接下来从几个方面介绍JAX原语必须遵循的一些基本规则,以便自定义的JAX函数能够执行转换。

使用已有原语

自定义新函数的简单方法就是使用JAX原语编写,或者使用JAX原语编写其他函数。比如在jax.lax模块中定义的函数。代码如下,


import jax

def multiply_add(a, b, c):
    
    d = jax.lax.mul(a, b)
    e = jax.lax.add(d, c)
    
    return e

def main():
    
    a = 2
    b = 4
    c = 6
    
    e = multiply_add(a, b, c)
    
    print(f"e = {e}")
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


e = 14

从上面代码可以看到,如果需要自定义一个新的符合JAX代码规则的函数,最好的方法就是使用现有的JAX原语。

自定义前向JVP和反向VJP

前面章节介绍过,JAX支持不同模式的自动求导。jax.grad()默认采取了方向模式自动求导。通过jax.jvp和jax.vjp可以指定自动求导模式。

  • jax.jvp:Jacobian-Vector Product,通过前向(forward-mode) 模式自动求导。根据原函数f、输入x和dx计算结果f(x)和d f(x)。在函数输入参数数量少于或等于输出参数数量的情况下,前向模式自动求导比反向模式更省内存,内存利用率更具优势。
  • jax.vjp:Vector-Jacobian Product,反向(reverse-mode)模式自动求导。根据原函数f、输入x计算函数结果f(x),并生成梯度函数。梯度函数中输入是df(x),输出是dx。
jax.jvp计算

下面通过一个例子理解jax.jvp的计算。


import jax

def function(a, b):
    
    return a ** 3 + b ** 2

def test():
    
    a = 4.
    b = 6.
    
    result = function(a, b)
    
    print("result = ", result)
    print("-----------------------")
    
    function_grad = jax.grad(function)
    result = function_grad(a, b)
    
    print("result = ", result)
    print("-----------------------")
    
    function_grad = jax.grad(function, argnums = [0, 1])
    result = function_grad(a, b)
    
    print("result = ", result)

def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


result =  100.0
-----------------------
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
result =  48.0
-----------------------
result =  (Array(48., dtype=float32, weak_type=True), Array(12., dtype=float32, weak_type=True))

这是一个简单的函数求导,首选计算了函数值,之后计算求导函数的值。下面换一种写法,通过自定义的求导函数来输出结果。代码如下,


import jax

# Attribute that the derivative of this funciton will be customized
@jax.custom_jvp
def function (a, b):
    
    return a ** 3 + b ** 2

@function.defjvp
def function_defjvp (primals, tangents):
    
    print("primals = ", primals)
    print("tangents = ", tangents)
    
    a, b = primals
    a_dot, b_dot = tangents
    
    primal_output = function(a, b)
    tangent_output = 3 * a_dot ** 2 + 0
    
    return primal_output, tangent_output

def test():
    
    a = 4.
    b = 6.
    
    result, result_dot = jax.jvp(function, (a, b), (a, b))
    
    print(f"result = {result}, result_dot = {result_dot}")
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()
  

运行结果打印输出如下,


primals =  (4.0, 6.0)
tangents =  (4.0, 6.0)
result = 100.0, result_dot = 48.0

通过结果可知,自定义了正向计算和正向求导计算方法。其中function导函数定义成了
f'\left( a, b \right) = 3a^{2} + 0

这是手动计算导函数的结果。当然,这种自定义jvp的方式给了jvp极大的弹性,已不仅仅局限于求导了,可以任意指定算法,而结果也不仅限于导函数的值。

因此,通过自定义的结果即可以显式地展示求导后的值。基于上面的代码,再举一个例子进行说明自定义过程。代码如下,


import jax

def function(a, b):
    
    return a ** 3 + b ** 2

@jax.custom_jvp
def function1(a, b):
    
    return a ** 3 + b ** 2

@function1.defjvp
def function1_jvp(primals, tangents):
    
    a, b = primals
    a_dot, b_dot = tangents
    
    primal_output = function1(a, b)
    tangent_output = a_dot + b_dot
    
    return primal_output, tangent_output

def test():
    
    a = 4.
    b = 6.
    
    function_grad = jax.grad(function, argnums = [0, 1])
    result = function_grad(a, b)
     
    print("Derivative of orginal function =", result)
    print("--------------------------")
    
    function1_grad = jax.grad(function1, argnums = [0, 1])
    result = function1_grad(a, b)
     
    print("Derivative of customized JVP function =", result)
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


Derivative of orginal function = (Array(48., dtype=float32, weak_type=True), Array(12., dtype=float32, weak_type=True))
--------------------------
Derivative of customized JVP function = (Array(1., dtype=float32, weak_type=True), Array(1., dtype=float32, weak_type=True))

从结果来看,对于同一个函数,通过自定义JVP的求导结果和原始函数的求导结果并不一致。这是因为在自定义JVP求导方法后,此时的grad函数的计算规则有了变化,不再是对function函数求导,而是简单的a_dot + b_dot。

jx.vjp计算

Jax.vjp计算过程和jax.jvp相似,也是需要预先定义好输入的求导算法。代码如下所示,


import jax

@jax.custom_vjp
def function(a, b):
    
    return 2 ** a ** 2 + b ** 2

def function_forward(a, b):
    
    result = function(a, b)
    
    return result, (a, b)

def function_backward(result, gradient):
    
    b, a = result
    
    return (b, a)

def register():
    
    function.defvjp(fwd = function_forward, bwd = function_backward)
    
def test():
    
    register()
    
    a = 4.
    b = 6.
    
    grad_function = jax.grad(fun = function)
    result = grad_function(a, b)
    
    print("result = ", result)
    print("-----------------------")
    
    grad_function = jax.grad(function, argnums = [0, 1])
    result = grad_function(a, b)
    
    print("result = ", result)
    print("-----------------------")

def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


result =  4.0
-----------------------
result =  (4.0, 6.0)
-----------------------

JAX使用JVP的目的之一就是为了求导结果的稳定性。例如,一个JAX函数的想完成下列数学函数的计算

f\left( x \right) = \log{\left( 1 + e^{x} \right)}

使用JAX完成函数的代码如下所示,


import jax

def logxp(x):
    
    return jax.numpy.log(1. + jax.numpy.exp(x))

def test():
    
    logxp_jit = jax.jit(logxp)
    
    result = logxp_jit(3.)
    
    print("result = ", result)
    print("-------------------")
    
    logxp_grad = jax.grad(logxp)
    log_grad_jit = jax.jit(logxp_grad)
    
    result = log_grad_jit(3.)
    
    print("result = ", result)
    print("-------------------")
    
    log_grad_jit_vmap = jax.vmap(log_grad_jit)
    
    result = log_grad_jit_vmap(jax.numpy.arange(4.))
    
    print("result = ", result)
    print("-------------------")
    
    result = logxp_grad(99.)
    
    print("result = ", result)
    print("-------------------")
    
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


result =  3.0485873
-------------------
result =  0.95257413
-------------------
result =  [0.5        0.7310586  0.8807971  0.95257413]
-------------------
result =  nan
-------------------

注意,当输入值为99.时,输出结果为nan,明显不对。原因在于计算导数时,

d\left( x \right) = \frac{e^{x}}{1 + e^{x}} e^{100} = inf

即,由于e^{x}在x为00时,值为inf,时的最终计算结果为nan。

为了解决这个问题,需要向JAX中场地自定义求导规则和方法,代码如下,


import jax

def logxp(x):
    
    return jax.numpy.log(1. + jax.numpy.exp(x))

@jax.custom_jvp
def logxp1(x):
    
    return jax.numpy.log(1. + jax.numpy.exp(x))

@logxp1.defjvp
def logxp1_jvp(primals, tangents):
    
    x, = primals
    x_dot, = tangents
    
    result = logxp(x)
    result_dot = (1 - 1 / (1 + jax.numpy.exp(x))) * x_dot
    
    return result, result_dot

def test():
    
    logxp_jit = jax.jit(logxp)
    
    result = logxp_jit(3.)
    
    print("result = ", result)
    print("-------------------")
    
    logxp_grad = jax.grad(logxp)
    log_grad_jit = jax.jit(logxp_grad)
    
    result = log_grad_jit(3.)
    
    print("result = ", result)
    print("-------------------")
    
    log_grad_jit_vmap = jax.vmap(log_grad_jit)
    
    result = log_grad_jit_vmap(jax.numpy.arange(4.))
    
    print("result = ", result)
    print("-------------------")
    
    result = logxp_grad(99.)
    
    print("result = ", result)
    print("-------------------")
    
    logxp_grad = jax.grad(logxp1)
    result = logxp_grad(99.)
    
    print("result = ", result)
    print("-------------------")
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

再次运行,结果打印输出如下,


result =  3.0485873
-------------------
result =  0.95257413
-------------------
result =  [0.5        0.7310586  0.8807971  0.95257413]
-------------------
result =  nan
-------------------
result =  1.0
-------------------

深入理解jax.custom_jvp和jax.custom_vjp

上一节介绍了jax.custom_jvp和jax.custom_vjp,下面介绍这两种自定义函数的高级用法。

使用jax.custom_jvp

下面通过一个简单例子来说说明custom_jvp基本使用方法,目标是使用custom_jvp定义一个前向函数。代码如下所示,

                                                                                                      
import jax

@jax.custom_jvp
def function(x):
    
    return 2 * x ** 2 + 3 * x

def function_jvp(primals, tangents):
    
    x, = primals
    x_dot, = tangents
    
    result = function(x)
    result_dot = 4 * x + 3 * x_dot
    
    return result, result_dot
    
def register():
    
    function.defjvp(function_jvp)
    
def main():
    
    register()
    
    a = 5.
    
    result = function(a)
    
    print("result = ", result)
    print("-------------------------")
    
    function_grad = jax.grad(function)
    
    result = function_grad(a)
    
    print("result = ", result)
    print("-------------------------")
    
    b = 6.
    
    result, result_dot = jax.jvp(function, (a,))
    
    print(f"result = {result}, result_dot = {result_dot}")
    
if __name__ == "__main__":
    
    main()

其中,

  • x, = primals用于定义输入的参数
  • t, = tangents用于标识自定义目标求导函数4 * x + 3 * x_dot。

运行结果打印输出如下,


result =  65.0
-------------------------
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
result =  3.0
-------------------------
result = 65.0, result_dot = 38.0

其中,

  • 第1行结果为65.0,是原函数function的输出值。
  • 第2行结果为3.0,是自定义求导函数4 * x + 3 * x_dot的偏导数值,公式如下,
    d\left( x \right) = \frac {d (4x + 3 x \_dot)}{d\left(x\_dot \right)} = x\_dot = 3
  • 第3行是jax.jvp直接输出的原函数以及自定义的求导函数的计算值。

也就是说,从一个原始函数function开始,通过function_jvp定义了所需要的求导函数形式,之后通过function.defjvp对其进行注册,从而使得JAX能够知道原函数以及需要求导函数的形式。另外,如上一节的例子,可以使用defjvp修饰符注册。下面使用defjvp修饰符进行自定义求导函数。代码如下所示,


import jax

@jax.custom_jvp
def function(x, y):
    
    return x ** 3 + y ** 2

@function.defjvp
def function_jvp(primals, tangents):
    
    x, y = primals
    x_dot, y_dot = tangents
    
    primal_output = function(x, y)
    tangent_output = x * y_dot + y * x_dot
    
    return primal_output, tangent_output

def main():
    
    a = 5.
    b = 6.
    
    function_grad = jax.grad(function)
    result = function_grad(a, b)
    
    print("result = ", result)
    print("--------------------------")
    
    result, result_dot = jax.jvp(function, (a, b), (a, b))
    
    print(f"result = {result}, result_dot = {result_dot}")
    print("--------------------------")
    
if __name__ == "__main__":
    
    main()

此外,defjvp还支持匿名方法,代码如下,


import jax

@jax.custom_jvp
def function(x, y):
    
    return x ** 3 + y ** 2

@function.defjvp
def function_jvp(primals, tangents):
    
    x, y = primals
    x_dot, y_dot = tangents
    
    primal_output = function(x, y)
    tangent_output = x * y_dot + y * x_dot
    
    return primal_output, tangent_output

def main():
    
    a = 5.
    b = 6.
    
    function_grad = jax.grad(function)
    result = function_grad(a, b)
    
    print("result = ", result)
    print("--------------------------")
    
    result, result_dot = jax.jvp(function, (a, b), (a, b))
    
    print(f"result = {result}, result_dot = {result_dot}")
    print("--------------------------")
    
    function_jvp_lambda = lambda primals, tangents: (primals[0], tangents[0])
    
    function.defjvp(function_jvp_lambda)
    function_grad = jax.grad(function)
    result = function_grad(a, b)
    
    print(f"result = {result}")
    print("--------------------------")
    
if __name__ == "__main__":
    
    main()

运行结果输出如下,


result =  6.0
--------------------------
result = 161.0, result_dot = 60.0
--------------------------
result = 1.0

请注意,此处tangents仅仅是为了满足jax.defjvp函数签名 ,实际并未参与运算。defjvp函数调用function来计算原函数输出。在高微分数的上下文中,每个微分变换的应用都将使用自定义的jvp规则,但规则调用原函数funciton时来计算原函数输出。

另外,在defjvp同样支持一些控制语句,比如,


import jax

@jax.custom_jvp

def function(x):
    
    return 2 * x + 3

@function.defjvp
def function_jvp(primals, tangents):
    
    x, = primals
    x_dot, = tangents
    
    if x >= 0:
        return function(x), x * x_dot
    else:
        return function(x), -x * x_dot
    
def main():
    
    grad_function = jax.grad(function)
    
    result = grad_function(1.0)
    
    print("result = ", result)
    print("------------------------")
    
    result = grad_function(-1.0)
    print("result = ", result)
    
if __name__ == "__main__":
    
    main()

运行结果打印输出如下,


result =  1.0
------------------------
result =  1.0

使用jax.custom_vjp

虽然jax.custom_jvp可以控制JAX中自定义函数的前向计算以及反向求导的计算规则,当在某些情况下,可能还是希望直接控制VJP的规则,即使用jax.custom_vjp来实现这一要求。代码如下所示,


@jax.custom_vjp
def function(x):
    
    return 2 * x ** 2 + 3 * x

def function_forward(x):
    
    return function(x), 4 * x + 3

def function_backward(dot_x, y_bar):
    
    return (dot_x * y_bar,)

def test():
    
    function.defvjp(function_forward, function_backward)
    
    function_grad = jax.grad(function)
    
    result = function_grad(3.0)
    
    print("result = ", result)
    
def main():
    
    test()
    
if __name__ == "__main__":
    
    main()

函数fucntion_forward是对正向求导的自定义,其返回值不仅仅是自定义的原始函数,还包含了手动计算后自定义函数的求导结果。

运行结果打印输出如下,


result =  15.0

结论

本章通过多个示例来说明jax.lax里原语的概念,以及如何编写符合原语规则的函数。另外,也介绍了jax.custom_jvp和jax.custom_vjp的详细使用方法。

内容较多,需要上手熟悉。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容