这里提到的关于BN层的使用方法是基于TensorFlow框架的,不过其他框架也类似,原理是一样的。
关于BN
what is BN?
Batch Normalization是由google提出的一种训练优化方法。参考论文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift
Normalization是数据标准化(归一化,规范化),Batch 可以理解为批量,加起来就是批量标准化。
先说Batch是怎么确定的。在CNN中,Batch就是训练网络所设定的图片数量batch_size。
why is BN?
BN 解决的问题是梯度消失与梯度爆炸。
在深度网络中,如果网络的激活输出很大,其梯度就很小,学习速率就很慢。假设每层学习梯度都小于最大值0.25,网络有n层,因为链式求导的原因,第一层的梯度小于0.25的n次方,所以学习速率就慢,对于最后一层只需对自身求导1次,梯度就大,学习速率就快。
这会造成的影响是在一个很大的深度网络中,浅层基本不学习,权值变化小,后面几层一直在学习,结果就是,后面几层基本可以表示整个网络,失去了深度的意义。
关于梯度爆炸,根据链式求导法,
第一层偏移量的梯度=激活层斜率1x权值1x激活层斜率2x…激活层斜率(n-1)x权值(n-1)x激活层斜率n
假如激活层斜率均为最大值0.25,所有层的权值为100,这样梯度就会指数增加。
如何在tensorflow 中使用 BN
建议的使用方式
- 训练阶段BN中的is_training设为True
- 模型保存时把BN中的参数也一并保存,主要是moving_mean和moving_variance相关的参数名称
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
# 设置保存模型
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
- 预测阶段BN设为False
小技巧
- 如果你的待预测数据量比较大,每次都是一大批量的数据同时预测,可以设为训练模型,此时会直接从待预测数据计算其对应的均值和方差。
- 如果是单个数据或数据分布差异不大,建议在训练阶段保存BN 层的参数,特别是移动均值和方差,在预测阶段设为预测模型,此时会使用训练阶段的mean和var。
原因说明
若在预测阶段BN的is_training设为True
此时当改变需要预测数据的batchsize时预测的label也跟着变,因为使用的是该batch中的数据进行标准化操作,当预测的batchsize越大,假如你的预测数据集和训练数据集的分布一致,结果就越接近于训练结果,但如果batchsize=1,那BN层就发挥不了作用,结果很难看。若在预测阶段BN的is_training设为False
那如果在预测时is_traning=false呢,但BN层的参数没有从训练中保存,那使用的就是随机初始化的参数,结果不堪想象。