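Introduction
Pruning is one of the common techniques for compressing deep learning models. Its goal is to make the weight parameters of models such as CNNs, RNNs, and Transformers sparse, i.e. the weights contain a large number of zero elements.
Advantages of model sparsification:
- Storage: if the model weights contain many zeros, they can be stored in a compressed format such as COO or CSR.
- Computation: because the weights contain many zeros, many modern accelerators such as NPUs include a zero-skipping unit, which reduces the computational cost.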
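To make the storage point concrete, here is a minimal sketch that converts a mostly-zero dense tensor into the COO and CSR layouts using PyTorch's built-in sparse tensors. The toy 3x3 tensor is an illustrative assumption, not a real model weight.

```python
import torch

# Toy "weight" with only two non-zero entries (illustrative values).
dense = torch.tensor([[0.0, 0.0, 1.5],
                      [0.0, 0.0, 0.0],
                      [2.0, 0.0, 0.0]])

# COO stores one (row, col) index pair plus one value per non-zero entry.
coo = dense.to_sparse().coalesce()
print(coo.indices())  # first row: row indices, second row: column indices
print(coo.values())

# CSR replaces the per-entry row indices with compressed row pointers.
csr = dense.to_sparse_csr()
print(csr.crow_indices())
print(csr.col_indices())
print(csr.values())
```

Whether a sparse layout actually saves memory depends on the sparsity level, because the indices take space too.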
The goal of this section is to get familiar with the pruning utilities in PyTorch and learn how to apply them. Let's start coding!
Requirements
Tested in the following environment:
- Ubuntu 20.04
- PyTorch 1.12
Implementation
Model definition
For simplicity, we use LeNet-5.
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

# https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
# %%
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16*5*5, 120)  # 5x5 dim
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten everything except the batch dimension
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

# inspect a module / layer
module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))
This prints the learnable parameters (weight and bias) of the first Conv2d layer before pruning; the list of buffers is still empty:
Output
[('weight', Parameter containing:
tensor([[[[-0.1544, 0.0351, 0.2471],
[-0.0788, 0.2216, -0.0925],
[-0.1486, -0.1366, -0.0963]]],
[[[ 0.2780, -0.1358, 0.2029],
[ 0.2228, 0.0061, -0.1716],
[-0.3228, -0.1036, 0.2223]]],
[[[-0.2228, 0.0742, -0.1789],
[-0.1888, -0.3132, -0.1999],
[ 0.0359, -0.1263, 0.2270]]],
[[[-0.2067, -0.2954, -0.1952],
[-0.2652, 0.2705, -0.1056],
[ 0.1010, -0.1888, -0.0087]]],
[[[-0.1197, -0.0913, -0.2631],
[-0.2442, 0.2834, -0.0278],
[ 0.1842, 0.1579, -0.3101]]],
[[[-0.2317, 0.1837, 0.1096],
[-0.0636, -0.1924, -0.3029],
[ 0.1714, 0.1079, 0.0050]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2799, 0.2889, 0.1455, 0.0563, -0.0082, 0.1129], device='cuda:0',
requires_grad=True))]
[]
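The 16*5*5 input size of fc1 in the model above follows from the layer arithmetic. The sketch below reproduces it, assuming a 28x28 single-channel input (an assumption: the input size is not stated in this section). The helper names conv_out and pool_out are made up for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output side length of max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


size = 28                               # assumed input height/width
size = pool_out(conv_out(size, 3), 2)   # conv1 (3x3), then 2x2 max pool
size = pool_out(conv_out(size, 3), 2)   # conv2 (3x3), then 2x2 max pool
flat = 16 * size * size                 # conv2 has 16 output channels
print(size, flat)
```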
Pruning the weight and bias of a single layer
The code below prunes the weight and bias of the conv1 layer.
prune.random_unstructured(module=module, name='weight', amount=0.3)
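- random_unstructured: the pruning method, here unstructured random pruning. The algorithm is simple, but because the resulting sparsity is unstructured it is not very friendly to hardware acceleration.
- name='weight': prune the weight; 'bias' can be given instead.
- amount: how much to prune. A float between 0 and 1 is a fraction, e.g. 0.3 prunes 30% of the weight elements. An integer is an absolute count, e.g. 10 sets 10 weight elements to 0.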
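The two meanings of amount can be checked by counting the zeros in the resulting mask. This is a minimal sketch on a small stand-in nn.Linear layer (an assumption chosen so the counts are easy to verify by hand), not on the LeNet used in this section.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Two small stand-in layers, each with 5 * 4 = 20 weight elements.
layer_frac = nn.Linear(4, 5)
layer_int = nn.Linear(4, 5)

# Float amount: fraction of elements to prune (0.3 * 20 = 6).
prune.random_unstructured(layer_frac, name='weight', amount=0.3)
# Integer amount: absolute number of elements to prune.
prune.random_unstructured(layer_int, name='weight', amount=10)

# The mask has a 0 at every pruned position.
zeros_frac = int((layer_frac.weight_mask == 0).sum())
zeros_int = int((layer_int.weight_mask == 0).sum())
print(zeros_frac, zeros_int)
```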
# Pruning a Module
# Prune the first Conv layer
prune.random_unstructured(module=module, name='weight', amount=0.3)
# After pruning, the original parameter 'weight' is removed and replaced by 'weight_orig' (the original, unpruned weight)
print(list(module.named_parameters()))
# The pruning mask is stored as the buffer 'weight_mask'
print(list(module.named_buffers()))
print(module.weight)
# The weight shape is unchanged by pruning, but about 30% of its elements are now 0
# The pruning method is registered as a forward pre-hook
print(module._forward_pre_hooks)
prune.l1_unstructured(module=module, name='bias', amount=3)
print(list(module.named_parameters()))  # now contains bias_orig
print(list(module.named_buffers()))
print(module.bias)
print(module._forward_pre_hooks)
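The relationship between weight, weight_orig and weight_mask can be verified directly. The sketch below uses a fresh stand-in Conv2d layer (an assumption, so the example runs on its own). It also calls prune.remove, which is not used elsewhere in this section but is part of torch.nn.utils.prune and makes the pruning permanent.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 6, 3)  # stand-in for model.conv1
prune.random_unstructured(conv, name='weight', amount=0.3)

# The pruned weight is the original weight multiplied by the binary mask.
matches = torch.equal(conv.weight, conv.weight_orig * conv.weight_mask)
params_pruned = {n for n, _ in conv.named_parameters()}
buffers_pruned = {n for n, _ in conv.named_buffers()}
print(matches, params_pruned, buffers_pruned)

# prune.remove drops weight_orig, weight_mask and the forward pre-hook,
# and stores the pruned tensor as a plain 'weight' parameter again.
prune.remove(conv, 'weight')
params_final = {n for n, _ in conv.named_parameters()}
buffers_final = {n for n, _ in conv.named_buffers()}
print(params_final, buffers_final, len(conv._forward_pre_hooks))
```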
Pruning multiple layers
For example, to prune all Conv2d and Linear layers in the network, simply iterate over its modules.
# Pruning multiple layers
# Prune every Conv2d and Linear layer
new_model = LeNet()
for name, module in new_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module=module, name='weight', amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
print(dict(new_model.named_buffers()).keys())
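One way to confirm the per-layer pruning ratios is to measure the fraction of zeros in each mask buffer. The following is a sketch under assumptions: it uses a small stand-in nn.Sequential network instead of the LeNet above so that it runs on its own, and mask_sparsity is a hypothetical helper.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Small stand-in network (not the LeNet above); only its weights matter here.
net = nn.Sequential(nn.Conv2d(1, 6, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16, 10))

for module in net.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)


def mask_sparsity(mask):
    """Fraction of zero entries in a pruning mask (hypothetical helper)."""
    return float((mask == 0).sum()) / mask.nelement()


# Buffer names are '<layer index>.weight_mask' for an nn.Sequential.
sparsity = {name: mask_sparsity(buf) for name, buf in net.named_buffers()}
for name, value in sparsity.items():
    print(f"{name}: {value:.2%}")
```

The Conv2d layer here has 54 weights, so a 20% request is rounded to 11 pruned elements and the measured sparsity is slightly above 20%.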
random_unstructured pruning
This function is one of the pruning methods already implemented in PyTorch: unstructured random pruning.
def random_unstructured(module, name, amount):
    r"""Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)
    """
    RandomUnstructured.apply(module, name, amount)
    return module
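Conceptually, the binary mask behind random unstructured pruning can be pictured as below. This is a simplified sketch under assumptions: random_mask is a hypothetical helper written for illustration, not PyTorch's own code, and it ignores details such as pruning a parameter that has already been pruned.

```python
import torch


def random_mask(tensor, n_prune):
    """Return a 0/1 mask shaped like `tensor` with `n_prune` randomly chosen zeros.

    Hypothetical helper for illustration only.
    """
    mask = torch.ones_like(tensor)
    if n_prune > 0:
        # Draw a random score per element and zero out the n_prune highest scores.
        scores = torch.rand_like(tensor)
        idx = torch.topk(scores.view(-1), k=n_prune).indices
        mask.view(-1)[idx] = 0
    return mask


weight = torch.randn(3, 4)
mask = random_mask(weight, n_prune=5)
pruned = weight * mask  # same shape as weight, 5 entries forced to 0
print(mask)
print(pruned)
```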