论文笔记-Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

论文原文:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate

1. 介绍

深度学习模型训练的主要方法是 Stochastic gradient descent (SGD) ,SGD 优化网络的参数\theta,从而最小化损失。

使用 mini-batch 近似估计损失函数对于参数的梯度,其估计的质量随着 batch size 的增大而提高,并且计算效率更高。

虽然随机梯度简单有效,但它需要仔细调整模型超参数,特别是优化中使用的学习率,以及模型参数的初始值。由于每一层的输入受所有前面层参数的影响,网络参数的微小变化都会随着网络变得更深而放大。

这种输入分布的变化会导致后面的每一层都要去适应一个新的分布,会导致整个网络学习变慢。

另一方面,当使用 sigmoid 的时候,随着网络层数加深,在训练中输入的分布很容易移动到饱和区域,使梯度变化很小,很容易产生梯度消失。通常的解决办法是使用 ReLU 作为非线性激活函数,并使用合适的初始化和较小的学习率。但是,如果可以让每一层的输入分布保持稳定以避免陷入饱和区,训练也会加速。

保持每一层输入分布的不变有利于使训练更有效。

作者将训练过程中深层网络内部节点分布的变化称为 internal covariante shift 。消除他可以显著加速神经网络的训练。方法就是作者提出的(现在神经网络标配的)Batch Normalization。

使用 BN 可以:

  • 减少参数梯度对于参数范围或者初始值的依赖,使得我们能使用更高的学习率
  • 减少了 Dropout 的需求
  • 使用 Sigmoid 等有饱和区域的非线性函数时,可以避免网络陷入饱和

2. 关于减少 internal covariante shift

作者定义 Internal Covariate Shift 作为训练过程中网络参数变化引起的网络激活分布的变化。为了改进训练,我们寻求减少 internal covariante shift 。在训练过程中,将每层输入的分布固定,以提高训练速度。

这一点已经被 LeCun 在1998年提出了,如果输入被白化(whiten),即线性变换为具有零均值和单位方差,并且去相关,网络训练收敛得更快。

通过对每一层的输入进行白化,我们将朝着实现输入的固定分布迈出一步,从而消除 internal covariante shift 的不良影响。

我们希望得到的是,对于任何参数值,网络总是可以得到需要的数据分布,这样就可以让损失函数对于模型参数的梯度也与 normalization 有关。假设 x 是某一层的输入, \chi 是整个训练集, normalization 可以写作 \hat{x} = Norm(x, \chi)

其中 x, \chi 都与模型参数 \theta 有关(因为是某一层的输出),因此在反向传播的时候需要计算梯度
\frac{\partial Norm(x,\chi )}{ \partial x} , \frac{\partial Norm(x,\chi)}{ \partial \chi }

由于白化计算需要计算协方差矩阵来进行去相关,计算成本高。一些研究也表明了只是用统计学的方法对单个训练样本或者某个位置的特征图进行处理,发现这样会改变数据的表达能力,因为它抛弃了数据范围的信息(discarding the absolute scale of activations)。 因此我们需要一个能够在网络中保留信息的 Normalizatin 操作。

3. Normalization Statistics via Mini-Batch

完全对每一层的输入进行白化操作计算昂贵且不是处处可微分的,因此作者进行了两点简化:

  • 分别对每个特征 normalize, 让它均值为0,方差为1
  • 由于使用了 mini-batch 训练, 作者利用每个 mini-batch , 而不是整个数据集,的激活值来估计均值和方差

简单的 normalization 会改变输入数据的表达能力,比如对于 sigmoid 函数, 会让输入被限制在线性区域,失去非线性。为此。作者加入了一个变换操作,以确保网络可以实现一个恒等变换(identity transform)。

对于每个激活值 x^{(k)}, 有y^{(k)} = \gamma^{(k)} \hat{x}^{(k)}+\beta^{(k)}

\gamma^{(k)} = \sqrt{Var[x^{(k)}]}, \beta^{[k]} = E[x^{(k)}] 时, 我们可以还原数据。

Batch Normalizing TRansforrm

BN 是可微的变换,在反向传播的时候可以计算各个参数的梯度。

求梯度

3.1 训练与推理

在训练阶段使用 BN 可以加速训练, 但是在推理阶段(验证、测试)可能没有 batch。

利用训练时得到的mini-batch 的统计值, m为 batch size,对所有 batch 的均值方差做平均:
E[x] \leftarrow E_B[ \mu_B], Var[x] \leftarrow \frac{m}{m-1}E_B[\sigma^2_B]

然后对测试的数据进行 normalization

y = \gamma \frac{x-E[x]}{\sqrt{Var[x]+\epsilon }}+ \beta

训练流程

3.2 卷积层BN

假设一个仿射变换 z =g(Wu+b), 其中W,b 是可学习的参数,这个形式可以包括全连接层和卷积层。g(\cdot) 代表非线性函数。

作者将 BN 直接加载非线性函数之前,也就是 normalize x= Wu+b。 原因是Wu+b 更有可能有一个对称的,非稀疏的分布,在这里做 normalization 更有可能得到一个稳定的分布。

因为 normalize了Wu+b,偏差b可以被忽略掉,因为它的影响将被随后的均值减法抵消,z= g(Wu+b) 被替换为 z=g(BN(Wu)),其中 BN 变换独立应用于 x = Wu 的每个维度,每个维度有一对单独的学习参数 γ^{(k)},β^{(k)}

3.3 BN可实现更高的学习率

在传统的深度网络中,过高的学习率可能会导致梯度爆炸或消失,以及陷入不良的局部最小值。批量标准化有助于解决这些问题。通过对整个网络的激活进行标准化,它可以防止参数的微小变化放大为梯度激活的较大和次优变化;比如,它可以防止训练陷入非线性的饱和状态。

BN 还使训练对参数规模更具弹性。通常,较大的学习率可能会增加层参数的规模,会放大反向传播过程中的梯度,并导致模型爆炸。然而,通过BN,通过层的反向传播不受其参数规模的影响。

假设有 BN((aW)u),在反向传播中 \frac{\partial BN((aW)u)}{\partial u} = \frac{1}{a} \cdot \frac{\partial BN(Wu)}{\partial W}

规模不影响每层雅可比矩阵,也不影响梯度传播。此外,较大的权重会导致较小的梯度,批量归一化将稳定参数增长。

3.4 Batch Normalization 正则化

模型使用 Batch Normalization 进行训练时,会看到一个训练样本与 mini-batch 中的其他样本相结合,并且训练网络不再为给定的训练示例生成确定性值。在作者的实验中,发现这种效果有利于网络的泛化。虽然 Dropout 通常用于减少过拟合,但在 BN 网络中,我们发现它可以被移除或减少使用。

4. pytorch 中的 BN 层

pytorch 中提供了 BatchNorm1dBatchNorm2dBatchNorm3d,这里以BatchNorm2d 为例子,参数包括:

  • num_features: 输入的通道数,如果输入大小为 (N,C,H,W),这里的值应该是C
  • eps: 一个保持计算中数值稳定的值,默认值 1e-5
  • momentum: 一个用来计算 running_mean 和 running_var 的值,可以设置为 None,会变为简单的计算平均,默认值为 0.1
  • affine: 布尔值,为 True 时 BN 层会有两个可学习的参数, 也就是论文中的 \gamma, \beta,默认值为 True
  • track_running_stats: 布尔值,为 True 时 会追踪计算 running_mean 和 running_var;为 False 时,不追踪统计信息,并将缓冲区中的 running_mean 和 running_var 初始化为 None。当它们为 None 时,BN 在训练和推理阶段dou会只使用 batch 的统计信息,默认值为 True

输出 size 与输入相同,都是 (N,C,H,W)。

关于momentum,与 optimizer 中控制梯度下降的 momentum 不同, 这里的momentum 是用在计算统计量的,\hat{x}_{new}= (1 - momentum) \times \hat{x}+ momentu \times x_t,其中 \hat{x}是估计的整个数据集的统计量,x_t是新观测到的值,也就是当前 mini-batch 计算得到的统计量。

这个公式可以这样理解,在上面的 algorithm2 中,推理阶段使用的 E[x], Var[x]需要每一个Batch的均值和方差,训练完成之后按照公式计算平均。对于 E[x],在第n个 batch时更新 E[x]= \frac{E[x] \times n + E_B}{n+1} = \frac{n}{n+1}E[x]+\frac{1}{n+1}E_B

momentum = \frac{1}{n+1}时, 就是上面的计算平均的公式。


参考:
https://zhuanlan.zhihu.com/p/50444499
https://blog.csdn.net/qq_37524214/article/details/108559989
https://blog.csdn.net/APTX2334869/article/details/102716147
https://blog.csdn.net/joyce_peng/article/details/103163048

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,546评论 6 507
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,224评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,911评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,737评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,753评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,598评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,338评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,249评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,696评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,888评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,013评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,731评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,348评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,929评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,048评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,203评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,960评论 2 355

推荐阅读更多精彩内容