混合精度训练介绍
Mixed-Precision Training是指在深度学习AI模型训练过程中不同的层Layer采用不同的数据精度进行训练, 最终使得训练过程中的资源消耗(GPU显存,GPU 算力)降低, 同时保证训练可收敛,模型精度与高精度FP32的结果接近。
CNN ResNet 混合精度训练
-
导入torch.cuda.amp package
由于CNN训练要求大量算力, 因此一般混合精度需要使用 NVIDIA Automatic Mixed Precision (AMP)包, NVIDIA的AMP以及集成到了Pyorch, 因此直接调用torch.cuda.amp
APIs.
混合精度主要用到 Loss-Scaling (损失缩放) + Auto-cast (自动精度选择/转换)
# Mixed-Precision Training
from torch.cuda.amp.grad_scaler import GradScaler
from torch.cuda.amp.autocast_mode import autocast
# 实例化一个GradeScaler对象
scaler = GradScaler()
- 对Training Loop进行修改, 修改2个地方
添加
autocast()
: autocast是一个Python context Manager, autocast 作用区域的代码在运行的时候会跟据OP的类型,自动转换为预定义好的低精度类型 (比如FP16)
*注意: autocast一般作用的代码区域为 Forward, Backward 阶段需要指定autocast, 因此在Forward阶段不同的layer (op)以及被设置了各自的精度模式, 在Backward阶段,采用和Forward相同的精度进行计算。添加
GradeScalar
: GradeScalar的目的是对权重的梯度矩阵值进行缩放(扩大化), 因为一般情况下的值非常小,如果采用低精度类型 (FP16),则导致下溢underflow. 解决方法之一就是希望将 进行缩放变大; 由于Loss函数导数具有线性性质, 因此也可以对 Loss进行缩放,实际等价于对梯度值进行了放大。
for i, (images, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
# move data to the same device as model
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# compute output
with autocast(enabled=args.mixed_precision, dtype=torch.float16):
output = model(images)
loss = criterion(output, target)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
# losses.update(loss.item(), images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
optimizer.zero_grad()
# loss.backward()
# optimizer.step()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
...
为了便于观察训练过程, 在代码中添加了Pytorch Profiler进行可视化:
为了说明情况, 可以只跑少数的几个batch即可
对应的CUDA kernel函数, FP16类型的
采用BFLOAT16进行混合精度
方法很简单,在autocast
的dtype设置为torch.bfloat16
。 除此之外,需要采用支持BFLOAT16类型的计算设备(TPU, >=NVIDIA Ampere/Volta 架构的GPU, 比如NVIDIA V100, A100, RTX 30/40系列)
with autocast(enabled=args.mixed_precision, dtype=torch.bfloat16):
output = model(images)
loss = criterion(output, target)