上一期介绍了Batch Normalization的前向传播,然而想法美好,然而能否计算、如何计算这些新参数才是重点。
系列目录
理解Batch Normalization系列1——原理
理解Batch Normalization系列2——训练及评估
理解Batch Normalization系列3——为什么有效及若干讨论
理解Batch Normalization系列4——实践
本文目录
1 训练阶段
1.1 反向传播
1.2 参数的初始化及更新
2 评估阶段
2.1 来自训练集的均值和方差
2.2 评估阶段的计算
3 总结
参考文献
先放出这张图,帮助记住。
图 1. BN的结构
1 训练阶段
引入BN,增加了、、、四个参数。
这四个参数的引入,能否计算梯度?它们分别是如何初始化与更新?
1.1 反向传播
神经网络的训练,离不开反向传播,必须保证BN的标准化、缩放平移两个操作必须可导。
缩放平移就是一个线性公式,求导很简单。而对于标准化时的统计量,看起来有点无从下手。其实是凭借图1的变量关系,可以绘制计算图,如图2所示。Frederik Kratzert 在这篇博文中有详细的计算,对每一个环节都进行了详细的描述。
图 2. 求解BN反向传播的计算图 (来源: 这篇博文)
由图2可见:
- 每个环节都可导
- 只要求出各个环节的导数
- 用链式法则(串联关系就相乘,并联关系就相加)求出总梯度。
狗尾续貂,对这个反传大致做了一个流程图,如图3所示,帮助理解。
图 3. BN层反传的流程图 (来源: 这篇博文)
注意,均值的梯度、方差的梯度的计算,只是为了保证梯度的反向传播链路的通畅,而不是为了更新自己(没明白下文还会解释);缩放因子和j和平移因子的梯度传播则和权重W一样,不影响反向传播链路的通畅,只是为了更新自己。
最后的结果就是原论文中表述:
图4. BN的反向传播. (来源: Batch Normalization Paper)
如果是从事学术,不妨练练手。
1.2 参数的初始化及更新
讨论一下图1中的6个参数的初始化及更新问题。
-
W
初始化用标准正态分布,更新用梯度下降。
与经典网络的初始化相同,初始化一个标准正态分布(即Xavier方法)。
-
b
省略掉该参数。
在经典的神经网络里,b作为偏置,用于解决那些W无法通过与x相乘搞定的"损失减少要求",即对于本层所有神经元的加权和进行各自的平移。而加入BN后,的作用正是进行平移。b的作用被所完全替代了,因此省略掉b。
了解过ResNet结构的朋友会发现该网络中的卷积,都没有偏置,为什么?下面截图是Kaiming He在github上回答原话。(踩坑无数必须体会深刻)
图5. BN的加入导致本层的偏置b失效
-
和
初始化取决于统计量,仅更新梯度,但不更新值本身。
在训练阶段,每个mini-batch上进行前向传播时,通过对本batch上的m个样本进行统计得到;
在反向传播时,计算出它们的梯度对的梯度、对的梯度,用于进行梯度传播。
但是和这两个值本身不必进行更新,因为在下一个mini-batch会计算自己的统计量,所以前一个mini-batch获得的和没意义。 -
和
初始化为1、0,更新用梯度下降。
根据我们在《理解Batch Normalization系列1——原理》的解读,作为“准方差”,初始化为一个全1向量;而作为"准均值”,初始化为一个全0向量,他俩的初始值对于刚刚完成标准正态化的来说,没起任何作用。
至于将要变成什么值,起多大作用,那就交给后续的训练。即采用梯度下降进行更新,方式同。
2 评估阶段
、是在整个训练集上训练出来的,与一样,训练结束就可获得。
然而,和是靠每一个mini-batch的统计得到,因为评估时只有一条样本,batch_size相当于是1,在只有1个向量的数据组上进行标准化后,成了一个全0向量,这可咋办?
2.1 来自训练集的均值和方差
做法是用训练集来估计总体均值和总体标准差。
-
简单平均法
把每个mini-batch的均值和方差都保存下来,然后训练完了求均值的均值,方差的均值即可。
-
移动指数平均(Exponential Moving Average)
这是对均值的近似。
仅以举例:
其中decay是衰减系数。即总均值是前一个mini-batch统计的总均值和本次mini-batch的加权求和。至于衰减率 decay在区间之间,decay越接近1,结果越稳定,越受较远的大范围的样本影响;decay越接近0,结果 越波动,越受较近的小范围的样本影响。
事实上,简单平均可能更好,简单平均本质上是平均权重,但是简单平均需要保存所有BN层在所有mini-batch上的均值向量和方差向量,如果训练数据量很大,会有较可观的存储代价。移动指数平均在实际的框架中更常见(例如tensorflow),可能的好处是EMA不需要存储每一个mini-batch的值,永远只保存着三个值:总统计值、本batch的统计值,decay系数。
在训练阶段同步获得了和后,在评估时即可对样本进行BN操作。
2.2 评估阶段的计算
为避免分母不为0,增加一个非常小的常数,并为了计算优化,被转换为:
这样,只要训练结束,就已知了,1个BN层对一条测试样本的前向传播只是增加了一层线性计算而已。
3 总结
用图6做个总结。
图6. BN层相关参数的学习方法
鬼斧神工的构造,鬼斧神工的参数获取方法,这么多鬼斧神工,需要好好消化消化。
请见下一期《理解Batch Normalization系列3——为什么有效及若干讨论》
参考文献
[1] https://arxiv.org/pdf/1502.03167v3.pdf
[2] https://r2rt.com/implementing-batch-normalization-in-tensorflow.html
[3] Adjusting for Dropout Variance in Batch Normalization and Weight Initialization
[4] https://www.jianshu.com/p/05f3e7ddf1e1
[8] https://panxiaoxie.cn/2018/07/28/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-Batch-Normalization/
[9] https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization