PyTorch模型减枝技术-pruning

介绍

减枝(prune)是深度学习模型压缩常见的技术之一, 目的是使得CNN/RNN/Transformer等模型的权重weight参数稀疏化 sparsity,即weight包含大量的0元素.
模型稀疏化的优点:

  1. 存储优势: 如果模型weight包含大量的0元素,实际存储中可以采用各种压缩格式,比如COO, CSR
  2. 计算优势: 由于包含大量的0元素, 因此现代的很多加速器比如NPU都设计了跳零单元 zero skipping unit, 减少了计算开销

本节的主要目的是认识并掌握PyTorch中对pruning技术的应用, let's coding!


Requirement

如下环境测试:

  • Ubuntu 20.04
  • PyTorch 1.12

代码实现

模型定义

简单起见,采用LeNet-5

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
# https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

# %%
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16*5*5, 120)  # 5x5 dim
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

# inspect a module / layer
module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))

输出第一个Conv2d layer的learanable parameters, weight, bias, prune之前的参数:
输出结果


[('weight', Parameter containing:
tensor([[[[-0.1544,  0.0351,  0.2471],
          [-0.0788,  0.2216, -0.0925],
          [-0.1486, -0.1366, -0.0963]]],


        [[[ 0.2780, -0.1358,  0.2029],
          [ 0.2228,  0.0061, -0.1716],
          [-0.3228, -0.1036,  0.2223]]],


        [[[-0.2228,  0.0742, -0.1789],
          [-0.1888, -0.3132, -0.1999],
          [ 0.0359, -0.1263,  0.2270]]],


        [[[-0.2067, -0.2954, -0.1952],
          [-0.2652,  0.2705, -0.1056],
          [ 0.1010, -0.1888, -0.0087]]],


        [[[-0.1197, -0.0913, -0.2631],
          [-0.2442,  0.2834, -0.0278],
          [ 0.1842,  0.1579, -0.3101]]],


        [[[-0.2317,  0.1837,  0.1096],
          [-0.0636, -0.1924, -0.3029],
          [ 0.1714,  0.1079,  0.0050]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2799,  0.2889,  0.1455,  0.0563, -0.0082,  0.1129], device='cuda:0',
       requires_grad=True))]
[]

对某一层Layer的Weight, bias Prune

代码如下, 对conv1 layer的weight, bias进行prune
prune.random_unstructured(module=module, name='weight', amount=0.3)

  • random_unstructure: prune 方法, 非结构化减枝, 这种算法简单,但是由于是非结构化,因此对硬件加速不是很友好.
  • name=weight, 代表对weight进行prune, 还可以是bias
  • amount: 减枝的程度, 如果是0~1之间的小数,例如0.3代表30%的weight参数进行减枝; 如果是整数, 例如10代表weight中10个元素减枝为0

# Pruning a Module
# Prune the first Conv layer

prune.random_unstructured(module=module, name='weight', amount=0.3)
# prune之后, 原始的weight被remove, 替换为 weight_orig(原始未prune的weight)
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.weight)
# prune前后的weight shape没有变化, 但是prune之后的weight出现了大量的0元素
# prune对象
print(module._forward_pre_hooks)
prune.l1_unstructured(module=module, name='bias', amount=3)
print(list(module.named_parameters())) # bias_ori
print(list(module.named_buffers()))
print(module.bias)
print(module._forward_pre_hooks)

对多个Layer 进行Prune

例如对net中所有的Conv2d, Linear layer进行Prune, 直接遍历layers



---
# 多层prune
# Conv2d, Linear进行Prune
new_model = LeNet()

for name, module in new_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module=module, name='weight', amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())

Random_unstructed Prune
此函数是PyTorch中已经实现的prune方法之一, 非结构化随机减枝

def random_unstructured(module, name, amount):
    r"""Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
                will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)

    """
    RandomUnstructured.apply(module, name, amount)
    return module
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 211,948评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,371评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,490评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,521评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,627评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,842评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,997评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,741评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,203评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,534评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,673评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,339评论 4 330
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,955评论 3 313
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,770评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,000评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,394评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,562评论 2 349

推荐阅读更多精彩内容