Post training 4-bit quantization of convolutional networks for rapid-deployment

一、摘要

  • 介绍了三种方法,用于CNN模型的超低比特量化(4bits)和比特数自动选择。
  • Analytical Clipping for Integer Quantization(ACIQ),一种阶段阈值选择方法。
  • Per-channel bit allocation,一种对feature map各个channel实现不同比特量化的方法
  • bias-correction,一种偏移修正方法, 用以提高量化后的精度

二 Analytical Clipping for Integer Quantization (ACIQ)

ACIQ是一种量化阈值选择方法。对于后量化而言,最直接的方法是等价量化,不做截取,但这样损失较为严重(原始连续分布长尾太大)。通常需要找到一个阈值T用来截断,不在区间[-T, T]的值截取为-T或T。这样可以有效提高后量化精度。所以,问题就转变为,如何选择较好的截取值T,传统方法用KL散度,遍历可能的截取值不断计算量化前后分布的KL散度然后选取KL散度最小的T值作为截取值。论文中提出的ACIQ即一种基于优化思想的阈值T选取方法。
该方法用于激活值的量化。

首先,对于一个tensor(feature map),ACIQ假设该tensor的分布服从两种可能:拉普拉斯分布或高斯分布。量化过程就是将服从该分布的tensor中的值量化到2^M离散区间中。其中M表示比特数。
ACIQ定义原始浮点分布的密度函数f(x)
,截取值\alpha 以及量化函数Q(x)
。所以,量化前和量化后的L2 loss就等于:

image.png

而整个量化问题就被转变为:求解\alpha
,使得上述的loss值最小。

从上述表达式不难看出,量化损失一共分为三段:负无穷到-\alpha截断产生的误差, -\alpha\alpha
之间的round量化误差,以及\alpha 到正无穷的截断误差。论文用可导函数来表示各个阶段的误差进而方便求解。论文正文里以tensor服从拉普拉斯分布的情况进行推导。
量化误差如下:

image.png

截断误差如下:
image.png

所以,最终的整体量化损失如下:
image.png

此时,量化函数被成功的转换成了一个可以求导的连续函数,只需要对其求偏导,就可以得到使量化误差最小的截断值:
image.png

其中, \alpha 为截取值, \beta
为拉普拉斯分布的参数。M为量化后的比特数。最后,求解公式在M = 2,3,4时,\alpha=T\beta ,T分别取值2.83, 3.89, 5.03。

上述即是ACIQ的核心原理,利用优化的思路来求解量化过程截断值进而最小化量化损失。注意ACIQ有一个较强的先验假设,即tensor的数据分布要符合拉普拉斯分布或高斯分布(高斯分布的截取值计算在论文的附页中)。
https://github.com/submission2019/cnn-quantization


        print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch == 'shufflenet':
            import models.ShuffleNet as shufflenet
            self.model = shufflenet.ShuffleNet(groups=8)
            params = torch.load('ShuffleNet_1g8_Top1_67.408_Top5_87.258.pth.tar')
            self.model = torch.nn.DataParallel(self.model, args.device_ids)
            self.model.load_state_dict(params)
        else:
            self.model = models.__dict__[args.arch](pretrained=True)

        set_node_names(self.model)
        # Mark layers before relue for fusing
        if 'resnet' in args.arch:
            resnet_mark_before_relu(self.model)

        # BatchNorm folding
        if 'resnet' in args.arch or args.arch == 'vgg16_bn' or args.arch == 'inception_v3':
            print("Perform BN folding")
            search_absorbe_bn(self.model)
            QM().bn_folding = Trueself.model.to(args.device)
        QM().quantize_model(self.model)

        if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet' and args.arch != 'mobilenetv2':
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                self.model.features = torch.nn.DataParallel(self.model.features, args.device_ids)
            else:
                self.model = torch.nn.DataParallel(self.model, args.device_ids)

        # define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.criterion.to(args.device)

        cudnn.benchmark = True
        # Data loading code
        valdir = os.path.join(args.data, 'val')

        if args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
            dataparallel = args.device_ids is not None and len(args.device_ids) > 1
            tfs = [mutils.TransformImage(self.model.module if dataparallel else self.model)]
        else:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            resize = 256 if args.arch != 'inception_v3' else 299
            crop_size = 224 if args.arch != 'inception_v3' else 299
            tfs = [
                transforms.Resize(resize),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                normalize,
            ]

        self.val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose(tfs)),
            batch_size=args.batch_size, shuffle=(True if (args.kld_threshold or args.aciq_cal or args.shuffle) else False),
            num_workers=args.workers, pin_memory=True)  


def run(self):
        if args.eval_precision:
            elog = EvalLog(['dtype', 'val_prec1', 'val_prec5'])
            print("\nFloat32 no quantization")
            QM().disable()
            val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
            elog.log('fp32', val_prec1, val_prec5)
            logging.info('\nValidation Loss {val_loss:.4f} \t'
                         'Validation Prec@1 {val_prec1:.3f} \t'
                         'Validation Prec@5 {val_prec5:.3f} \n'
                         .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
            print("--------------------------------------------------------------------------")

            for q in [8, 7, 6, 5, 4]:
                args.qtype = 'int{}'.format(q)
                print("\nQuantize to %s" % args.qtype)
                QM().quantize = True
                QM().reload(args, get_params())
                val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
                elog.log(args.qtype, val_prec1, val_prec5)
                logging.info('\nValidation Loss {val_loss:.4f} \t'
                             'Validation Prec@1 {val_prec1:.3f} \t'
                             'Validation Prec@5 {val_prec5:.3f} \n'
                             .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
                print("--------------------------------------------------------------------------")
            print(elog)
            elog.save('results/precision/%s_%s_clipping.csv' % (args.arch, args.threshold))
        elif args.custom_test:
            log_name = 'results/custom_test/%s_max_mse_%s_cliping_layer_selection.csv' % (args.arch, args.threshold)
            elog = EvalLog(['num_8bit_layers', 'indexes', 'val_prec1', 'val_prec5'], log_name, auto_save=True)
            for i in range(len(max_mse_order_id)+1):
                _8bit_layers = ['conv0_activation'] + max_mse_order_id[0:i]
                print("it: %d, 8 bit layers: %d" % (i, len(_8bit_layers)))
                QM().set_8bit_list(_8bit_layers)
                val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
                elog.log(i+1, str(_8bit_layers), val_prec1, val_prec5)
            print(elog)
        else:
            val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
            if self.ml_logger is not None and self.ml_logger.mlflow.active_run() is not None:
                self.ml_logger.mlflow.log_metric('top1', val_prec1)
                self.ml_logger.mlflow.log_metric('top5', val_prec5)
                self.ml_logger.mlflow.log_metric('loss', val_loss)

            return val_loss, val_prec1, val_prec5

三、Per-channel bit-allocation

Per-channel bit-allocation核心思想是允许一个tensor中的各个channel的量化bits不相同(channel1可能用4bits量化;channel2可能用5bits量化,channel3可能用3bits量化),并找到每个channel的最佳量化bits。同时要求平均每个channel的量化bits值为4。

首先,该方法借用了ACIQ中对连续密度函数和量化后的离散分布之间的L2误差的定义,并在该定义的基础上,

  • 引入各个channel的量化比特数作为限定条件。
  • 并引入拉格朗日乘子

得到channel变化bits时的量化损失表达式。


image.png

其中,

  • 第一项为ACIQ中的量化loss表达式,
  • 第二项表示拉格朗日乘子引入的约束损失。
  • M_i表示第i个channel的量化比特,B表示所有channel的量化间隔总和。

对拉格朗日表达式求偏导,得:


image.png

最终,我们可以得到**各个channel的最佳量化bit和原始浮点数据分布的关系如下:最终,我们可以得到各个channel的最佳量化bit和原始浮点数据分布的关系如下:


image.png

总体来说,该方法延续了ACIQ的优化方法求解量化问题思想,从算法角度来看可以在有限的bits量化时通过灵活调整各个channel的量化比特数,达到量化损失最小的情况。但在实际应用中,如此各个channel分比特量化必须要配合非常特殊的硬件加速实现,实际应用价值值得商榷。

四、Bias-Correction

Bias-Correction方法主要用于对weight的量化,作者观察到量化前后权重分布的均值和方差纯在固有的偏差,该方法即通过一种简单的方法补偿weight量化前后偏移的mean和var。

image.png

五、实验

作者选用了常见的几种分类模型,进行了组合/消融实验。同时也分别进行了Weight和feature map用不同比特数量化的实验结果对比:


image.png
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,001评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,210评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,874评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,001评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,022评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,005评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,929评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,742评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,193评论 1 309
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,427评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,583评论 1 346
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,305评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,911评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,564评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,731评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,581评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,478评论 2 352

推荐阅读更多精彩内容