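I won't go over the benefits of Batch Normalization here; see the paper for details. I actually read the paper a long time ago, but thanks to procrastination (and to using Keras) I never really understood how BN is implemented in code (the train and test phases), so I'm writing these notes~
In short: during training each batch is normalized with its own mean and variance, so when training finishes the mean and variance on hand are just those of the last batch. At test time we therefore use a moving average of the means and variances of all training batches instead. That is the only difference; for how it works in practice, see the examples below.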
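To make the train/test difference concrete, here is a minimal NumPy sketch (not the TensorFlow implementation). The function name `batch_norm`, the `state` dict and the toy data are made up for illustration, and the learnable scale and shift are left out to keep it short.

```python
import numpy as np

def batch_norm(x, state, training, momentum=0.99, eps=1e-5):
    """Normalize a (batch, features) array; `state` holds the moving statistics."""
    if training:
        # Training: normalize with the statistics of the current batch.
        mean, var = x.mean(axis=0), x.var(axis=0)
        # Also fold them into the moving averages for later use at test time.
        state["mean"] = momentum * state["mean"] + (1 - momentum) * mean
        state["var"] = momentum * state["var"] + (1 - momentum) * var
    else:
        # Test: ignore the batch statistics and use the moving averages.
        mean, var = state["mean"], state["var"]
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
state = {"mean": np.zeros(3), "var": np.ones(3)}

# "Training": 2000 batches drawn from a distribution with mean 5 and std 2.
for _ in range(2000):
    batch = rng.normal(loc=5.0, scale=2.0, size=(64, 3))
    batch_norm(batch, state, training=True)

# "Testing": normalize fresh data with the accumulated moving statistics.
test_out = batch_norm(rng.normal(5.0, 2.0, size=(1000, 3)), state, training=False)
```

After the loop, `state["mean"]` and `state["var"]` should sit close to the statistics of the data distribution, which is why the test output ends up roughly zero-mean and unit-variance even though no batch statistics are computed at test time.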
Reference code:
https://github.com/soloice/mnist-bn (the author uses TF-Slim)
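import tensorflow as tf  # TensorFlow 1.x
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)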
###########################################################
# The list of values in the collection with the given name, or an empty list if
# no value has been added to that collection. The list contains the values in
# the order under which they were collected.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
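# Key step.
# We need to run several ops in one step: the training op and the moving-average
# updates of the BN statistics (the moving-average logic itself is already
# packaged up). Running several ops together relies on tf.control_dependencies
# and tf.group to set up dependencies between ops (see my other note for details).
# I tried the following three variants (pick one); all of them work.
# Note: slim.learning.create_train_op already runs the ops in
# tf.GraphKeys.UPDATE_OPS by default (update_ops=None), which is probably why
# the three variants behave the same here.

# Variant 1: attach the updates to the loss. To have an effect on its own, this
# has to come before the train op is built from cross_entropy.
if update_ops:
    updates = tf.group(*update_ops)
    cross_entropy = control_flow_ops.with_dependencies([updates], cross_entropy)

# Variant 2: attach the updates to the train op.
if update_ops:
    updates = tf.group(*update_ops)
    train_step = control_flow_ops.with_dependencies([updates], train_step)

# Variant 3: build the train op under a control dependency on the updates.
with tf.control_dependencies([tf.group(*update_ops)]):
    train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)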
Using tf.nn.batch_normalization(): this is one of the lower-level APIs
from tensorflow.python.training import moving_averages

# Note: create_bn_var is a helper that is not shown in this post
# (presumably a thin wrapper around tf.get_variable).
def batchnorm(inputs, scope, epsilon=1e-05, momentum=0.99, is_training=True):
    inputs_shape = inputs.get_shape().as_list()
    params_shape = inputs_shape[-1:]  # one parameter per channel (last axis)
    axis = list(range(len(inputs_shape) - 1))  # reduce over every axis except the last

    with tf.variable_scope(scope):
        beta = create_bn_var("beta", params_shape,
                             initializer=tf.zeros_initializer())
        gamma = create_bn_var("gamma", params_shape,
                              initializer=tf.ones_initializer())
        # for inference
        moving_mean = create_bn_var("moving_mean", params_shape,
                                    initializer=tf.zeros_initializer(), trainable=False)
        moving_variance = create_bn_var("moving_variance", params_shape,
                                        initializer=tf.ones_initializer(), trainable=False)

    if is_training:
        # Training: use the statistics of the current batch ...
        mean, variance = tf.nn.moments(inputs, axes=axis)
        # ... and create ops that fold them into the moving averages.
        update_move_mean = moving_averages.assign_moving_average(moving_mean,
                                                                 mean, decay=momentum)
        update_move_variance = moving_averages.assign_moving_average(moving_variance,
                                                                     variance, decay=momentum)
        # Custom collection: fetch it with tf.get_collection("_update_ops_")
        # and run these ops together with the train op.
        tf.add_to_collection("_update_ops_", update_move_mean)
        tf.add_to_collection("_update_ops_", update_move_variance)
    else:
        # Inference: use the accumulated moving averages.
        mean, variance = moving_mean, moving_variance
    return tf.nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
I have tried this out myself, also on MNIST (link)
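For reference, the arithmetic behind the function above can be reproduced in a few lines of NumPy. This is only a sketch: the NHWC input shape and the random data are assumptions for illustration, and `batch_normalization` here is a plain NumPy stand-in that follows the formula documented for tf.nn.batch_normalization, not the TensorFlow op itself.

```python
import numpy as np

def batch_normalization(x, mean, variance, offset, scale, eps=1e-5):
    # scale * (x - mean) / sqrt(variance + eps) + offset
    return scale * (x - mean) / np.sqrt(variance + eps) + offset

# Toy NHWC feature map: 8 images, 4x4 spatial size, 16 channels.
x = np.random.default_rng(0).normal(3.0, 2.0, size=(8, 4, 4, 16))

axes = tuple(range(x.ndim - 1))   # (0, 1, 2): every axis except the channel axis
mean = x.mean(axis=axes)          # shape (16,), one value per channel
variance = x.var(axis=axes)       # shape (16,)

gamma = np.ones(16)               # scale, initialized to 1 as in the code above
beta = np.zeros(16)               # offset, initialized to 0 as in the code above

y = batch_normalization(x, mean, variance, beta, gamma)
```

The per-channel statistics have the same shape as `params_shape` in the TensorFlow version, and broadcasting over the last axis takes care of the rest.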
tf.layers.batch_normalization() is another fairly high-level API
# Example from the official docs
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
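The moving statistics only change when the update ops actually run, and with a momentum of 0.99 they change slowly. The plain-Python sketch below gives a rough feel for that. It assumes the simple update rule `new = momentum * old + (1 - momentum) * batch_value` with no zero-debias correction; the helper names are made up for illustration.

```python
import math

def steps_to_warm_up(momentum, tolerance):
    """Updates needed before the initial value's leftover weight drops below `tolerance`."""
    # After n updates the initial value still carries a weight of momentum ** n.
    return math.ceil(math.log(tolerance) / math.log(momentum))

def moving_average_after(n, batch_value, momentum, init=0.0):
    """Moving average after n updates that all see the same batch statistic."""
    value = init
    for _ in range(n):
        value = momentum * value + (1 - momentum) * batch_value
    return value

n = steps_to_warm_up(momentum=0.99, tolerance=0.01)
warm = moving_average_after(n, batch_value=5.0, momentum=0.99)  # update ops were run
cold = moving_average_after(0, batch_value=5.0, momentum=0.99)  # update ops never ran
```

If the update ops are never run, the moving mean stays at its initial value (`cold`), so inference would normalize with the wrong statistics. With momentum=0.99 it takes several hundred updates before the moving average gets within 1% of the true value.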
Generally speaking, these three are enough~