numba 是一款可以将 Python 函数编译为机器代码的编译器。
以球放盒子问题为例,示例代码如下:
import numpy as np
from numba import jit
import time
num_experiments = 200000
num_balls = 21
num_boxes = 40
num_target_balls = 3
num_target_boxes = 0
@jit(nopython=True)
def get():
num_target_boxes = 0
for i in range(num_experiments):
boxes = np.zeros(num_boxes)
for j in range(num_balls):
box_index = np.random.randint(0, num_boxes)
boxes[box_index] += 1
num_target_boxes += np.count_nonzero(boxes == num_target_balls)
start_time = time.time()
num_target_boxes = get()
end_time = time.time()
print("Time taken: ", end_time - start_time)
expected_num_target_boxes = num_target_boxes / num_experiments
print(expected_num_target_boxes)
对比没有添加 @jit(nopython=True)
的代码,运行时间从 10.28s 缩短到了 1.29s,效果非常明显。
但是 numba 的 jit 本身会花费时间。如果被加速的函数没有包含像是 for 循环的语句,会起到反作用。
针对相同的概率问题,这次选用纯矩阵的方式进行运算:
import numpy as np
from numba import jit
import time
num_experiments = 200000
num_balls = 21
num_boxes = 40
num_target_balls = 3
num_target_boxes = 0
@jit(nopython=True)
def get():
boxes = np.random.randint(0, num_boxes, size=(num_experiments, num_balls))
num_target_boxes = np.sum(np.count_nonzero(boxes == num_target_balls, axis=1))
return num_target_boxes
start_time = time.time()
num_target_boxes = get()
end_time = time.time()
print("Time taken: ", end_time - start_time)
expected_num_target_boxes = num_target_boxes / num_experiments
print(expected_num_target_boxes)
对比没有添加 @jit(nopython=True)
的代码,运行时间从 0.067s 反而增加到了 2.03s。
结论:numba 能有效加速纯 Numpy 运算,但 numba 不是万能的。
要编写适用于 numbda.jit 的代码,目前发现要遵循以下原则:
- 这个函数不能使用 numpy 以外的库
- 这个函数没有出现除 python 整数、python 小数和 numpy 一系列类型以外的变量类型