Introduction
For a recent project I needed to study how to use PyTorch's PTQ and QAT quantization. Recent PyTorch releases recommend FX Graph Mode Quantization.
FX Graph Mode Quantization Demo
The main workflow of Post-Training Quantization (PTQ) static quantization in PyTorch FX Graph mode consists of step 1 ~ step 5:
- step 1: Setup: choose the quantization scheme, e.g. per-channel vs. per-tensor/per-layer QScheme, and the representable quantized value range (Qmin, Qmax)
- step 2: prepare_fx:
* a) Convert the input model (nn.Module) into a GraphModule (IR conversion)
* b) Fuse ops/subgraphs in the graph (e.g. conv + relu --> ConvReLU)
* c) Insert Observers around ops such as Conv and Linear to collect statistics (value ranges) of the activation feature maps
- step 3: Feed calibration data through the model to calibrate the activations
- step 4: Compute the weight and activation quantization parameters (e.g. scale, zero_point), converting the model FP32 --> INT8
- step 5: Evaluate the accuracy of the quantized INT8 model
from torchvision.models import resnet18, resnet50
import torch
from torch.ao.quantization import quantize_fx, get_default_qconfig
import os
import copy
import utils
def calibrate(model, data_loader, num_batch, device):
utils.evaluate(model=model, data_loader=data_loader, neval_batches=num_batch, n_print=1, device=device)
if __name__ == '__main__':
device = torch.device('cuda', 0)
eval_batch_size = 32
imagenet_data='/media/wei/Document/ImageNet/ILSVRC2012'
model_fp = resnet50(pretrained=True, progress=True).to(device)
model_fp.eval()
_, test_dataloader = utils.prepare_dataloader(data_path=imagenet_data, eval_batch_size=eval_batch_size, num_workers=8)
utils.evaluate(model=model_fp, criterion=None, data_loader=test_dataloader, device=device)
# ResNet-18: Tested on imagenet-val: batch:3125 Acc@1 56.25 ( 69.76), Acc@5 75.00 ( 89.08)
# ResNet-50: batch:1560 Acc@1 59.38 ( 76.18), Acc@5 90.62 ( 92.87)
# torch quantization
model_prepare = copy.deepcopy(model_fp)
model_prepare.eval()
# choose the quantization configuration
# NOTE: qconfig_dict must be a dict; the "" key sets the global (default) qconfig
qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}
model_prepare = quantize_fx.prepare_fx(model=model_prepare, qconfig_dict=qconfig_dict)
model_prepare.eval()
# calibrate: determine the quantization ranges of the activations
calibrate(model_prepare, test_dataloader, 10, device)
# using the configured scheme and the calibrated statistics, convert the model FP32 --> INT8
# (fbgemm INT8 kernels run on x86 CPUs, so move the prepared model to the CPU first)
quantized_model = quantize_fx.convert_fx(graph_module=model_prepare.cpu())
quantized_model.eval()
# evaluate the accuracy of the quantized model (on the CPU)
utils.evaluate(quantized_model, data_loader=test_dataloader, device=torch.device('cpu'))
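As a quick sanity check, you can print the prepared and converted models to see what FX actually did (a sketch reusing the variables above):
print(model_prepare.graph)  # observer nodes inserted around conv/linear ops
print(quantized_model)      # fused, quantized modules such as QuantizedConvReLU2d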
Thanks to the lean design of the PyTorch FX Graph quantization API, we only need a handful of lines and minimal changes to quantize a model. Next, let's explore the implementation behind FX Graph quantization, analyzing the process step by step.
PyTorch FX Graph Quantization, Step 1: Choosing the Quantization Configuration
Below is PyTorch's default PTQ quantization configuration. 'fbgemm' is a matrix-computation library that provides INT8 Conv, Linear, and other ops on server-side x86 CPUs.
qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}
def get_default_qconfig(backend='fbgemm'):
"""
Returns the default PTQ qconfig for the specified backend.
Args:
* `backend`: a string representing the target backend. Currently supports `fbgemm`
and `qnnpack`.
Return:
qconfig
"""
if backend == 'fbgemm':
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
weight=default_per_channel_weight_observer)
elif backend == 'qnnpack':
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
weight=default_weight_observer)
else:
qconfig = default_qconfig
return qconfig
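Note that the backend string also selects the INT8 kernel library used at runtime. A minimal check (assuming your torch build ships both engines):
import torch
print(torch.backends.quantized.supported_engines)  # e.g. ['qnnpack', 'fbgemm', 'none']
torch.backends.quantized.engine = 'fbgemm'         # run INT8 ops with fbgemm on x86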
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
We can see that a qconfig has two parts, configuring the quantization of activations and weights separately: activations use HistogramObserver, a histogram-based per-tensor/per-layer asymmetric scheme, while weights use PerChannelMinMaxObserver, a per-channel symmetric scheme.
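A quick way to inspect both parts (a small sketch; the prints instantiate the observers that with_args wraps):
from torch.ao.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')
print(qconfig.activation())  # HistogramObserver, per-tensor affine, reduce_range=True
print(qconfig.weight())      # PerChannelMinMaxObserver, per-channel symmetric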
Why? Why do activations and weights use different quantization schemes?
- Weight quantization:
- The distribution of weight elements differs from activations: weights are usually zero-mean with a roughly symmetric (Gaussian-like) distribution, so symmetric quantization is a natural fit
- Symmetric quantization also reduces the arithmetic inside quantized ops, because zero_point = 0 (see the sketch below)
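To see why zero_point = 0 helps, here is a minimal sketch of linear (de)quantization using the standard formulas:
# Standard linear quantization (a sketch):
#   x_q = clamp(round(x_f / scale) + zero_point, quant_min, quant_max)
#   x_f ~= scale * (x_q - zero_point)
# With symmetric quantization zero_point == 0, so the subtraction (and the
# cross terms it would create inside an INT8 matmul) disappears.
import torch
def quantize(x, scale, zero_point, quant_min=-128, quant_max=127):
    return torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
def dequantize(x_q, scale, zero_point):
    return scale * (x_q - zero_point)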
For more background, see the quantization whitepaper from Qualcomm AI Research.
The Role of the Observer
In short, an Observer watches the data distribution and computes the quantization parameters scale and zero_point. Let's dig into the code.
Analyzing the PerChannelMinMaxObserver class
class PerChannelMinMaxObserver(_ObserverBase):
r"""Observer module for computing the quantization parameters based on the
running per channel min and max values.
This observer uses the tensor min/max statistics to compute the per channel
quantization parameters. The module records the running minimum and maximum
of incoming tensors, and uses this statistic to compute the quantization
parameters.
Args:
ch_axis: Channel axis
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
memoryless: Boolean that controls whether observer removes old data when a new input is seen.
This is most useful for simulating dynamic quantization, especially during QAT.
The quantization parameters are computed the same way as in
:class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
that the running min/max values are stored per channel.
Scales and zero points are thus computed per channel as well.
.. note:: If the running minimum equals to the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
min_val: torch.Tensor
max_val: torch.Tensor
def __init__(
self,
ch_axis=0,
dtype=torch.quint8,
qscheme=torch.per_channel_affine,
reduce_range=False,
quant_min=None,
quant_max=None,
factory_kwargs=None,
memoryless=False,
) -> None:
super(PerChannelMinMaxObserver, self).__init__(
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
)
self.memoryless = memoryless
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
self.ch_axis = ch_axis
self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
if (
self.qscheme == torch.per_channel_symmetric
and self.reduce_range
and self.dtype == torch.quint8
):
raise NotImplementedError(
"Cannot reduce range for symmetric quantization for quint8"
)
def forward(self, x_orig):
return self._forward(x_orig)
def _forward(self, x_orig):
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach() # avoid keeping autograd tape
min_val = self.min_val
max_val = self.max_val
x_dim = x.size()
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x.permute(new_axis_list)
# Need to match dtype of min/max because the updates to buffers
# are done in place and types need to match for comparisons
y = y.to(self.min_val.dtype)
y = torch.flatten(y, start_dim=1)
if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless:
min_val, max_val = torch.aminmax(y, dim=1)
else:
min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
min_val = torch.min(min_val_cur, min_val)
max_val = torch.max(max_val_cur, max_val)
self.min_val.resize_(min_val.shape)
self.max_val.resize_(max_val.shape)
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)
return x_orig
@torch.jit.export
def calculate_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
To compute the parameters needed for quantization, PyTorch defines a family of Observers, such as MinMaxObserver, MovingAverageMinMaxObserver, and so on. They all inherit from a common base class, and two key methods define an Observer's job:
- forward(self, x_orig): observe the min/max values of the incoming tensor (here, the weight)
- calculate_qparams(self): compute scale and zero_point
forward(self, x_orig)
What this method does:
- Input: x_orig, the weight tensor; a CNN weight is typically a 4D tensor of shape Oc * Ic * Kh * Kw
- Output/result: the observed min and max values
When the Observer is instantiated, the __init__() parameter ch_axis=0 specifies the channel dimension: ch_axis=0 means the Observer tracks min/max along the weight's Oc (output channels) dimension. The core line that observes them:
min_val, max_val = torch.aminmax(y, dim=1)
After the permute, Oc sits on axis 0 and the remaining dimensions are flattened into axis 1, so aminmax(dim=1) reduces over axis 1 and yields Oc (min_val, max_val) pairs, i.e., each output channel of the weight gets its own scale and zero_point:
def _forward(self, x_orig):
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach() # avoid keeping autograd tape
min_val = self.min_val
max_val = self.max_val
x_dim = x.size()
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x.permute(new_axis_list)
# Need to match dtype of min/max because the updates to buffers
# are done in place and types need to match for comparisons
y = y.to(self.min_val.dtype)
y = torch.flatten(y, start_dim=1)
if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless:
min_val, max_val = torch.aminmax(y, dim=1)
else:
min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
min_val = torch.min(min_val_cur, min_val)
max_val = torch.max(max_val_cur, max_val)
self.min_val.resize_(min_val.shape)
self.max_val.resize_(max_val.shape)
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)
return x_orig
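To make the per-channel reduction concrete, here is a tiny standalone sketch (the shapes are made up for illustration; with ch_axis=0 the permute is a no-op):
import torch
w = torch.randn(8, 3, 3, 3)          # a conv weight: Oc=8, Ic=3, Kh=Kw=3
y = torch.flatten(w, start_dim=1)    # shape (8, 27): one row per output channel
min_val, max_val = torch.aminmax(y, dim=1)
print(min_val.shape, max_val.shape)  # torch.Size([8]) torch.Size([8])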
calculate_qparams
What this method does:
As the name suggests, it computes the quantization parameters scale & zero_point (for linear quantization). Let's walk through the implementation:
- Input: the observed max_val and min_val, plus the predefined qmax and qmin
- Output: the computed scale and zero_point
def _calculate_qparams(
self, min_val: torch.Tensor, max_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Calculates the quantization parameters, given min and max
value tensors. Works for both per tensor and per channel cases
Args:
min_val: Minimum values per channel
max_val: Maximum values per channel
Returns:
scales: Scales tensor of shape (#channels,)
zero_points: Zero points tensor of shape (#channels,)
"""
if not check_min_max_valid(min_val, max_val):
return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
quant_min, quant_max = self.quant_min, self.quant_max
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device
scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
if (
self.qscheme == torch.per_tensor_symmetric
or self.qscheme == torch.per_channel_symmetric
):
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = torch.max(scale, self.eps)
if self.dtype == torch.quint8:
if self.has_customized_qrange:
# When customized quantization range is used, down-rounded midpoint of the range is chosen.
zero_point = zero_point.new_full(
zero_point.size(), (quant_min + quant_max) // 2
)
else:
zero_point = zero_point.new_full(zero_point.size(), 128)
elif self.qscheme == torch.per_channel_affine_float_qparams:
scale = (max_val - min_val) / float(quant_max - quant_min)
scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
# We use the quantize function
# xq = Round(Xf * inv_scale + zero_point),
# setting zero_point to (-1 * min *inv_scale) we get
# Xq = Round((Xf - min) * inv_scale)
zero_point = -1 * min_val / scale
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, self.eps)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
# For scalar values, cast them to Tensors of size 1 to keep the shape
# consistent with default values in FakeQuantize.
if len(scale.shape) == 0:
# TODO: switch to scale.item() after adding JIT support
scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
if len(zero_point.shape) == 0:
# TODO: switch to zero_point.item() after adding JIT support
zero_point = torch.tensor(
[int(zero_point)], dtype=zero_point.dtype, device=device
)
if self.qscheme == torch.per_channel_affine_float_qparams:
zero_point = torch.tensor(
[float(zero_point)], dtype=zero_point.dtype, device=device
)
return scale, zero_point
Core code for computing the quantization parameters scale and zero_point
- Symmetric quantization:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
- Asymmetric (affine) quantization:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, self.eps)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
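Plugging in numbers makes the difference between the two schemes obvious. A small sketch with a made-up observed range [-0.4, 1.2]:
import torch
min_val, max_val = torch.tensor(-0.4), torch.tensor(1.2)
# Symmetric (e.g. qint8 weights, quant_min=-128, quant_max=127):
# the range is centered at 0 and zero_point stays 0.
max_val_pos = torch.max(-min_val, max_val)
scale_sym = max_val_pos / (float(127 - (-128)) / 2)                  # ~0.0094
zero_point_sym = 0
# Affine (e.g. quint8 activations, quant_min=0, quant_max=255):
# the range covers [min_val, max_val] exactly.
scale_aff = (max_val - min_val) / float(255 - 0)                     # ~0.0063
zero_point_aff = 0 - torch.round(min_val / scale_aff).to(torch.int)  # 64
print(scale_sym.item(), scale_aff.item(), zero_point_aff.item())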
That covers how the quantization parameters of an nn.Conv2d weight are computed and how PerChannelMinMaxObserver is implemented. Next, let's look at how the quantization parameters of activations are computed.
Computing Activation Quantization Parameters: HistogramObserver
When choosing the quantization configuration, the default backend='fbgemm' uses HistogramObserver for activations, i.e., the quantization parameters are derived from a histogram analysis.
if backend == 'fbgemm':
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
weight=default_per_channel_weight_observer)
How HistogramObserver works
- Initialization: __init__()
- bins=2048 by default: histogram statistics need a bin count, i.e., the [min_val, max_val] range is split into 2048 intervals
- qscheme=per_tensor_affine: the quantization granularity is per-tensor/per-layer affine, meaning the whole tensor shares a single scale and zero_point pair rather than one per channel
class HistogramObserver(_ObserverBase):
r"""
The module records the running histogram of tensor values along with
min/max values. ``calculate_qparams`` will calculate scale and zero_point.
Args:
bins: Number of bins to use for the histogram
upsample_rate: Factor by which the histograms are upsampled, this is
used to interpolate histograms with varying ranges across observations
dtype: Quantized data type
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
The scale and zero point are computed as follows:
1. Create the histogram of the incoming inputs.
The histogram is computed continuously, and the ranges per bin change
with every new tensor observed.
2. Search the distribution in the histogram for optimal min/max values.
The search for the min/max values ensures the minimization of the
quantization error with respect to the floating point model.
3. Compute the scale and zero point the same way as in the
:class:`~torch.ao.quantization.MinMaxObserver`
"""
histogram: torch.Tensor
min_val: torch.Tensor
max_val: torch.Tensor
def __init__(
self,
bins: int = 2048,
upsample_rate: int = 128,
dtype: torch.dtype = torch.quint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
quant_min=None,
quant_max=None,
factory_kwargs=None,
) -> None:
# bins: The number of bins used for histogram calculation.
super(HistogramObserver, self).__init__(
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
)
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
self.bins = bins
self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
self.upsample_rate = upsample_rate
- Observe the statistics of the activation tensor (a running histogram plus running min/max):
def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
if x_orig.numel() == 0:
return x_orig
x = x_orig.detach()
min_val = self.min_val
max_val = self.max_val
same_values = min_val.item() == max_val.item()
is_uninitialized = min_val == float("inf") and max_val == float("-inf")
if is_uninitialized or same_values:
min_val, max_val = torch.aminmax(x)
self.min_val.resize_(min_val.shape)
self.min_val.copy_(min_val)
self.max_val.resize_(max_val.shape)
self.max_val.copy_(max_val)
assert (
min_val.numel() == 1 and max_val.numel() == 1
), "histogram min/max values must be scalar."
torch.histc(
x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
)
else:
new_min, new_max = torch.aminmax(x)
combined_min = torch.min(new_min, min_val)
combined_max = torch.max(new_max, max_val)
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
(
combined_min,
combined_max,
downsample_rate,
start_idx,
) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
assert (
combined_min.numel() == 1 and combined_max.numel() == 1
), "histogram min/max values must be scalar."
combined_histogram = torch.histc(
x, self.bins, min=int(combined_min), max=int(combined_max)
)
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram
else:
combined_histogram = self._combine_histograms(
combined_histogram,
self.histogram,
self.upsample_rate,
downsample_rate,
start_idx,
self.bins,
)
self.histogram.detach_().resize_(combined_histogram.shape)
self.histogram.copy_(combined_histogram)
self.min_val.detach_().resize_(combined_min.shape)
self.min_val.copy_(combined_min)
self.max_val.detach_().resize_(combined_max.shape)
self.max_val.copy_(combined_max)
return x_orig
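Putting it together, a minimal standalone sketch (random data stands in for real calibration batches):
import torch
from torch.ao.quantization.observer import HistogramObserver
obs = HistogramObserver(bins=2048, dtype=torch.quint8,
                        qscheme=torch.per_tensor_affine, reduce_range=True)
for _ in range(4):                  # pretend these are calibration batches
    obs(torch.randn(32, 64) * 2.0)  # forward() only accumulates statistics
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)            # a single (scale, zero_point) pair: per-tensor
Note that, unlike MinMaxObserver, calculate_qparams here does not simply use the raw min/max: as the docstring above says, it searches the histogram for the min/max that minimize the quantization error before computing scale and zero_point.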