MegEngine Python 层模块串讲(下)

前面的文章中,我们简单介绍了在 MegEngine imperative 中的各模块以及它们的作用。对于新用户而言可能不太了解各个模块的使用方法,对于模块的结构和原理也是一头雾水。Python 作为现在深度学习领域的主流编程语言,其相关的模块自然也是深度学习框架的重中之重。

模块串讲将对 MegEngine 的 Python 层相关模块分别进行更加深入的介绍,会涉及到一些原理的解释和代码解读。Python 层模块串讲共分为上、中、下三个部分,本文将介绍 Python 层的 quantization 模块。量化是为了减少模型的存储空间和计算量,从而加速模型的推理过程。在量化中,我们将权重和激活值从浮点数转换为整数,从而减少模型的大小和运算的复杂性。通过本文读者将会对量化的基本原理和使用 MegEngine 得到量化模型有所了解。

降低模型内存占用利器 —— quantization 模块

量化是一种对深度学习模型参数进行压缩以降低计算量的技术。它基于这样一种思想:神经网络是一个近似计算过程,不需要其中每个计算过程的绝对的精确。因此在某些情况下可以把需要较多比特存储的模型参数转为使用较少比特存储,而不影响模型的精度。

量化通过舍弃数值表示上的精度来追求极致的推理速度。直觉上用低精度/比特类型的模型参数会带来较大的模型精度下降(称之为掉点),但在经过一系列精妙的量化处理之后,掉点可以变得微乎其微。

如下图所示,量化通常是将浮点模型(常见神经网络的 Tensor 数据类型一般是 float32)处理为一个量化模型(Tensor 数据类型为 int8 等)。

1.png

量化基本流程

MegEngine 中支持工业界的两类主流量化技术,分别是训练后量化(PTQ)和量化感知训练(QAT)。

  1. 训练后量化(Post-Training Quantization, PTQ

    训练后量化,顾名思义就是将训练后的 Float 模型转换成低精度/比特模型。

    比较常见的做法是对模型的权重(weight)和激活值(activation)进行处理,把它们转换成精度更低的类型。虽然是在训练后再进行精度转换,但为了获取到模型转换需要的一些统计信息(比如缩放因子 scale),仍然需要在模型进行前向计算时插入观察者(Observer)。

    使用训练后量化技术通常会导致模型掉点,某些情况下甚至会导致模型不可用。可以使用小批量数据在量化之前对 Observer 进行校准(Calibration),这种方案叫做 Calibration 后量化。也可以使用 QAT 方案。

  2. 量化感知训练(Quantization-Aware Training, QAT

    QAT 会向 Float 模型中插入一些伪量化(FakeQuantize)算子,在前向计算过程中伪量化算子根据 Observer 观察到的信息进行量化模拟,模拟数值截断的情况下的数值转换,再将转换后的值还原为原类型。让被量化对象在训练时“提前适应”量化操作,减少训练后量化的掉点影响。

    而增加这些伪量化算子模拟量化过程又会增加训练开销,因此模型量化通常的思路是:

    • 按照平时训练模型的流程,设计好 Float 模型并进行训练,得到一个预训练模型;
    • 插入 ObserverFakeQuantize 算子,得到 Quantized-Float 模型(QFloat 模型)进行量化感知训练;
    • 训练后量化,得到真正的 Quantized 模型(Q 模型),也就是最终用来进行推理的低比特模型。

    过程如下图所示(实际使用时,量化流程也可能会有变化):

2.png
  1. 注意这里的量化感知训练 QAT 是在预训练好的 QFloat 模型上微调(Fine-tune)的(而不是在原来的 Float 模型上),这样减小了训练的开销,得到的微调后的模型再做训练后量化 PTQ(“真量化”),QModel 就是最终部署的模型。

模型(Model)与模块(Module

量化是一个对模型(Model)的转换操作,但其本质其实是对模型中的模块( Module) 进行替换。

MegEngine 中,对应与 Float ModelQFloat ModelQ ModelModule 分别为:

  1. 进行正常 float 运算的默认 Module
  2. 带有 ObserverFakeQuantize 算子的 qat.QATModule
  3. 无法训练、专门用于部署的 quantized.QuantizedModule

Conv 算子为例,这些 Module 对应的实现分别在:

量化配置 QConfig

量化配置包括 ObserverFakeQuantize 两部分,要设置它们,用户可以使用 MegEngine 预设配置也可以自定义配置。

  1. 使用 MegEngine 预设配置

    MegEngine 提供了多种量化预设配置

    ema_fakequant_qconfig 为例,用户可以通过如下代码使用该预设配置:

import megengine.quantization as Q
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
  1. 用户自定义量化配置

    用户还可以自己选择 ObserverFakeQuantize,灵活配置 QConfig 灵活选择 weight_observeract_observerweight_fake_quantact_fake_quant)。

    可选的 ObserverFakeQuantize 可参考量化 API 参考页面。

QConfig 提供了一系列用于对模型做量化的接口,要使用这些接口,需要网络的 Module 能够在 forward 时给权重、激活值加上 Observer 以及进行 FakeQuantize

模型转换的作用是:将普通的 Float Module 替换为支持这些操作的 QATModule(可以训练),再替换为 QuantizeModule(无法训练、专用于部署)。

Conv2d 为例,模型转换的过程如图:

3.png

在量化时常常会用到算子融合(Fusion)。比如一个 Conv2d 算子加上一个 BatchNorm2d 算子,可以用一个 ConvBn2d 算子来等价替代,这里 ConvBn2d 算子就是 Conv2dBatchNorm2d 的融合算子。

MegEngine 中提供了一些预先融合好的 Module,比如 ConvRelu2dConvBn2dConvBnRelu2d 等。使用融合算子会使用底层实现好的融合算子(kernel),而不会分别调用子模块在底层的 kernel,因此能够加快模型的速度,而且框架还无需根据网络结构进行自动匹配和融合优化,同时存在融合和不需融合的算子也可以让用户能更好的控制网络转换的过程。

实现预先融合的 Module 也有缺点,那就是用户需要在代码中修改原先的网络结构(把可以融合的多个 Module 改为融合后的 Module)。

模型转换的原理是,将父 Module 中的 Quantable (可被量化的)子 Module 替换为新 Module。而这些 Quantable submodule 中可能又包含 Quantable submodule,这些 submodule 不会再进一步转换,因为其父 Module 被替换后的 forward 计算过程已经改变了,不再依赖于这些子 Module

有时候用户不希望对模型的部分 Module 进行转换,而是保留其 Float 状态(比如转换会导致模型掉点),则可以使用 disable_quantize 方法关闭量化。

比如下面这行代码关闭了 fc 层的量化处理:

model.fc.disable_quantize()

由于模型转换过程修改了原网络结构,因此模型保存与加载无法直接适用于转换后的网络,读取新网络保存的参数时,需要先调用转换接口得到转换后的网络,才能用 load_state_dict 将参数进行加载。

量化代码

要从一个 Float 模型得到一个可用于部署的量化模型,大致需要经历三个步骤:

  1. 修改网络结构。将 Float 模型中的普通 Module 替换为已经融合好的 Module,比如 ConvBn2dConvBnRelu2d 等(可以参考 imperative/python/megengine/module/quantized 目录下提供的已融合模块)。然后在正常模式下预训练模型,并且在每轮迭代保存网络检查点。

    ResNet18BasicBlock 为例,模块修改前的代码为:

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.bn1 = M.BatchNorm2d
         self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
         self.bn2 = M.BatchNorm2d
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.Sequential(
            M.Conv2d(in_channels, channels, 1, stride, bias=False)
            M.BatchNorm2d
         )

      def forward(self, x):
         identity = x
         x = F.relu(self.bn1(self.conv1(x)))
         x = self.bn2(self.conv2(x))
         identity = self.downsample(identity)
         x = F.relu(x + identity)
         return x

注意到现在的前向中使用的都是普通 Module 拼接在一起,而实际上许多模块是可以融合的。

用可以融合的模块替换掉原先的 Module

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.ConvBn2d(in_channels, channels, 1, 1, bias=False)
         )
         self.add_relu = M.Elemwise("FUSE_ADD_RELU")

      def forward(self, x):
         identity = x
         x = self.conv_bn_relu1(x)
         x = self.conv_bn2(x)
         identity = self.downsample(identity)
         x = self.add_relu(x, identity)
         return x

注意到此时前向中已经有许多模块使用的是融合后的 Module

再对该模型进行若干论迭代训练,并保存检查点:

for step in range(0, total_steps):
    # Linear learning rate decay
    epoch = step // steps_per_epoch
    learning_rate = adjust_learning_rate(step, epoch)

    image, label = next(train_queue)
    image = tensor(image.astype("float32"))
    label = tensor(label.astype("int32"))

    n = image.shape[0]

    loss, acc1, acc5 = train_func(image, label, net, gm)  # traced
    optimizer.step().clear_grad()

    # Save checkpoints

完整代码见:

-   [修改前的模型结构](https://github.com/MegEngine/Models/blob/master/official/vision/classification/resnet/model.py)
-   [修改后的模型结构](https://github.com/MegEngine/Models/blob/master/official/quantization/models/resnet.py)
  1. 调用 quantize_qat 方法 将 Float 模型转换为 QFloat 模型,并进行微调(量化感知训练或校准,取决于 QConfig)。

    使用 quantize_qat 方法将 Float 模型转换为 QFloat 模型的代码大致为:

from megengine.quantization import ema_fakequant_qconfig, quantize_qat

model = ResNet18()

# QAT
quantize_qat(model, ema_fakequant_qconfig)

# Or Calibration:
# quantize_qat(model, calibration_qconfig)

Float 模型转换为 QFloat 模型后,加载预训练 Float 模型保存的检查点进行微调 / 校准:

if args.checkpoint:
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

# Fine-tune / Calibrate with new traced train_func
# Save checkpoints

完整代码见:

-   [Finetune](https://github.com/MegEngine/Models/blob/master/official/quantization/finetune.py)
-   [Calibration](https://github.com/MegEngine/Models/blob/master/official/quantization/calibration.py)
  1. 调用 quantize 方法将 QFloat 模型转换为 Q 模型,也就是可用于模型部署的量化模型。

需要在推理的方法中设置 tracecapture_as_const=True,以进行模型导出:

from megengine.quantization import quantize

@jit.trace(capture_as_const=True)
def infer_func(processed_img):
    model.eval()
    logits = model(processed_img)
    probs = F.softmax(logits)
    return probs

quantize(model)

processed_img = transform.apply(image)[np.newaxis, :]
processed_img = processed_img.astype("int8")
probs = infer_func(processed_img)

infer_func.dump(output_file, arg_names=["data"])

调用了 quantize 后,model 就从 QFloat 模型转换为了 Q 模型,之后便使用这个 Quantized 模型进行推理。

调用 dump 方法将模型导出,便得到了一个可用于部署的量化模型。

完整代码见:

小结

MegEngine Python 层模块串讲系列到这里就结束了,我们介绍了用户在使用 MegEngine 时主要会接触到的 python 层的各个模块的主要功能、结构以及使用方法,此外还有一些原理性的介绍。对于各模块具体实现感兴趣的读者可以参考 MegEngine 官方文档github。之后的文章我们会对 MegEngine 开发相关工具以及 MegEngine 底层的实现做更深入的介绍。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,816评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,729评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,300评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,780评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,890评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,084评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,151评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,912评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,355评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,666评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,809评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,504评论 4 334
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,150评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,882评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,121评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,628评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,724评论 2 351

推荐阅读更多精彩内容