原文链接:https://arxiv.org/pdf/2104.08500.pdf 发表:Arxiv
code:无
编辑:牛涛
据我所知第一篇在cv领域中的transformer剪枝,主要是通过给MHSA和MLP添加mask并进行L1正则实现稀疏化。
transformer中,MHSA及紧接的MLP可以写成如下
其中X是n*d维度的输入(n表示特征的个数,d表示特征的维度),当经过MHSA得到Y后,用两个MLP处理得到这个block的输出Z
作者在n*d中d这个维度进行稀疏化,认为每个特征只需要小于d个元素就可以表示了。作者通过在每个1*d上乘以一个二进制向量来实现置零。写成公式的话如下
a*是二进制向量,扩展成对角线矩阵再右乘X,效果就等同于mask掉d维中的对应元素。但是直接通过反向传播优化a*比较困难(因为是二进制离散的),所以作者把a*松弛到了实数域,并通过在损失函数中添加关于a的L1范数进行正则化。
当网络收敛后,根据a的大小在给定阈值下剪枝,并对剪枝后网络微调以恢复精度。mask在MHSA和MLP中都进行了添加,剪枝过程写成如下公式
文章还给出了示意图
值得一提的是,剪枝都是发生在MLP之前,不知道放在后边会不会有什么影响?