这篇blog是对之前写的一篇博客的补充:
网上有很多介绍如何计算卷积网络运算量的文章,基本都是介绍卷积还有全连接等一些常用的层是如何计算的,但很少有介绍反卷积层的运算量如何计算。因为我做分割比较多,一般分割的网络都会带有反卷积层,所以估算网络运算量的时候,都是需要算上反卷积的。写这篇文章的目的主要是为了告诉读者,卷积、反卷积、分组卷积和分组反卷积的运算量分别是如何计算出来的,虽然最终的计算公式都很简单,但是具体是如何估算的这里会介绍下。
本文相关代码,计算MXNet网络运算量的小工具:
普通卷积层运算量计算:
这个普通卷积层的运算量很多文章都已经讲过如何计算了,这里也重复讲下,为后面介绍反卷积的运算量做铺垫。卷积的运算一般的实现就是转化为矩阵乘法运算,首先来看一下卷积运算的简单的示意图:
首先左上角定义了输入和输出的feature map的形状,然后假设卷积核大小是(K, K),所以权值的形状就是 Cout * Cin * K * K,然后一般来说实现卷积的前向是通过首先对输入的feature map应用im2col操作,从 Cin * Hin * Win 形状的矩阵,转换成形状是 Cin * K * K * Hout * Wout 的矩阵,然后与权值相乘,就得到右边的输出。所以卷积前向的运算量看第一行就可以算出来了:
如果还有偏置的话,还要加上加偏置的运算量:
当然卷积运算的时候除了乘法还有加法,而我这了只算了乘法的次数,而且我看下面这个仓库的代码在计算运算量的时候也是只算了乘法。
https://github.com/albanie/convnet-burden
虽然计算运算量不会计算反向过程,但是卷积的反向和接下来要介绍和反卷积的前向是对应的,所以也简单说下,在反向过程中,求输入的梯度的时候把权值转置一下,然后与输出的梯度相乘就得到中间结果,然后再做一个col2im操作把中间结果回填到输入梯度矩阵的对应位置上。
普通反卷积:
然后我们来看下普通反卷积的运算量的计算方法,首先看一下反卷积前向和后向运算过程的示意图:
左上角也是定义了反卷积的输入与输出的feature map大小,然后这里反卷积的权值的形状与卷积有点不同,是 Cin * Cout * K * K,这是因为反卷积的前向和后向操作分别是对应卷积的后向和前向,也就是刚好反过来的。然后我们直接看反卷积的前向操作,和卷积的后向操作对应,权值做转置,然后与输入feature map做一个乘法,这里可以看成是一个1x1的卷积,输出通道数是 Cout * K * K,然后的到中间结果,然后再做一个col2im的操作回填到输出feature map对应的位置上。所以反卷积的运算量如下:
如果还有偏置的话则是:
所以如果有偏置存在的话,计算反卷积的运算量是需要知道输入与输出feature map大小的。
分组卷积:
分组卷积的运算量其实就是直接把卷积的运算量除以组数,比如分为g组,继续沿用上面卷积的运算量公式的话,那么分组卷积的运算量为:
加上偏置的话就是:
具体是怎么算出来的呢,直接看下面的示意图就应该很清晰了:
左上角定义了输入与输出feature map的大小还有卷积的分组数,则根据分组卷积的定义,输出feature map的通道Cout被分成了g组,每组里面的Cout / g个feature map链接输入的对应索引的 Cin / g 个通道的feature map,所以看上图,在把输入作im2col操作的时候也是按组来做的,每组都会生成一个 (Cin / g * K * K) * Hout * Wout的矩阵,然后与对应的权值做乘法,就是图中的相同颜色部分,然后每组做完乘法就得到了输出feature map,然后如果还有偏置,则是最后再加上,所以分组卷积的运算量就可以求到的了。
分组反卷积:
然后来看下反卷积,有了分组卷积的铺垫,分组反卷积也不难求,分组反卷积的FP同样也是对应分组卷积的BP:
同样的,左上角定义了分组反卷积的输入和输出feature map大小,分组数为g。同样的输出feature map的通道Cout被分成了g组,每组里面的Cout / g个feature map链接输入的对应索引的 Cin / g 个通道的feature map,然后在前向过程中,对于每组的计算,权值首先需要转置一下,得到 Cout / g * K * K * Cin / g的权值矩阵然后和输入对应的组数做乘法,然后得到输出对应的组的中间结果,然后每一组的中间结果再通过col2im回填到输出feature map对应的组的位置。
所以分组反卷积的运算量如下:
如果有偏置的话就是:
如果想更加详细的了解代码上的实现,读者可以参考MXNet中反卷积权值shape的推断部分,还有反卷积前向部分代码。
相关资料:
CNN 模型所需的计算力(flops)和参数(parameters)数量是怎么计算的?