BN层介绍
BN,全称Batch Normalization,是2015年提出的一种方法,在进行深度网络训练时,大都会采取这种算法。
原文链接:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
BN解决的问题:深度神经网络随着网络深度加深,训练起来越困难,收敛越来越慢。
这个问题出现的原因:
深度神经网络涉及到很多层的叠加,而每一层的参数更新会导致上层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断去重新适应底层的参数更新。为了训好模型,我们需要非常谨慎地去设定学习率、初始化权重、以及尽可能细致的参数更新策略。
Google 将这一现象总结为 Internal Covariate Shift,简称 ICS.
机器学习领域有个很重要的假设:独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。而ICS现象的存在,导致输入的分布老是变化,不符合独立同分布的假设,因此网络模型很难稳定的去学习。
那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的
因此,BN的基本思想其实相当直观:因为深层神经网络在做非线性变换前的输入值(就是那个y=Wx+B,x是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近,所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,网络的输出就不会很大,可以得到比较大的梯度,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
BN层的实现
从论文中给出的伪代码可以看出来BN层的计算流程是:
1.计算样本均值。
2.计算样本方差。
3.样本数据标准化处理。
4.进行平移和缩放处理。引入了γ和β两个参数。来训练γ和β两个参数。引入了这个可学习重构参数γ、β,让我们的网络可以学习恢复出原始网络所要学习的特征分布。
BN就是对不同样本的同一特征做归一化。
BN层的作用
BN层的作用主要有三个:
1.加快网络的训练和收敛的速度;
2.控制梯度爆炸防止梯度消失;
3.防止过拟合。
接下来就分析一下为什么BN层有着三个作用。
加快网络的训练和收敛的速度
在深度神经网络中中,如果每层的数据分布都不一样的话,将会导致网络非常难收敛和训练,而如果把 每层的数据都在转换在均值为零,方差为1 的状态下,这样每层数据的分布都是一样的训练会比较容易收敛。
控制梯度爆炸防止梯度消失
梯度消失:在深度神经网络中,如果网络的激活输出很大,其对应的梯度就会很小,导致网络的学习速率就会很慢,假设网络中每层的学习梯度都小于最大值0.25,网络中有n层,因为链式求导的原因,第一层的梯度将会小于0.25的n次方,所以学习速率相对来说会变的很慢,而对于网络的最后一层只需要对自身求导一次,梯度就大,学习速率就会比较快,这就会造成在一个很深的网络中,浅层基本不学习,权值变化小,而后面几层网络一直学习,后面的网络基本可以表征整个网络,这样失去了深度的意义。(使用BN层归一化后,网络的输出就不会很大,梯度就不会很小)
梯度爆炸:第一层偏移量的梯度=激活层斜率1x权值1x激活层斜率2x…激活层斜率(n-1)x权值(n-1)x激活层斜率n,假如激活层斜率均为最大值0.25,所有层的权值为100,这样梯度就会指数增加。(使用bn层后权值的更新也不会很大
防止过拟合
在网络的训练中,BN的使用使得一个minibatch中所有样本都被关联在了一起,因此网络不会从某一个训练样本中生成确定的结果,即同样一个样本的输出不再仅仅取决于样本的本身,也取决于跟这个样本同属一个batch的其他样本,而每次网络都是随机取batch,这样就会使得整个网络不会朝这一个方向使劲学习。一定程度上避免了过拟合。
为什么BN层一般用在线性层和卷积层后面,而不是放在非线性单元后
原文中是这样解释的,因为非线性单元的输出分布形状会在训练过程中变化,归一化无法消除他的方差偏移,相反的,全连接和卷积层的输出一般是一个对称,非稀疏的一个分布,更加类似高斯分布,对他们进行归一化会产生更加稳定的分布。其实想想也是的,像relu这样的激活函数,如果你输入的数据是一个高斯分布,经过他变换出来的数据能是一个什么形状?小于0的被抑制了,也就是分布小于0的部分直接变成0了,这样不是很高斯了。
参考文献:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Batch Normalization原理与实战