个人认为 Batch Normalization(批归一化,下文简称 BN)是一个非常重要但又很容易被忽略的知识点,目前几乎所有的神经网络都会用到。我在用 CIFAR-10 数据集测试时发现,同样的网络,加入 BN 层比不加 BN 层的验证集准确率高出 10% 左右。这也印证了吴恩达老师在课程中所讲的:BN 层会有轻微的正则化效果。
class BatchNormalize(tf.keras.layers.Layer):
    """Batch normalization over every axis except the last one.

    When ``training`` is truthy, inputs are normalized with the mean and
    variance of the current batch, and exponential moving averages of those
    statistics are updated. Otherwise the stored moving averages are used.

    Args:
        name: Layer name.
        epsilon: Small constant passed to ``tf.nn.batch_normalization`` as
            the variance epsilon.
        decay: Weight of the previous moving average in each update:
            ``moving = moving * decay + batch * (1 - decay)``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, name='BatchNormal', epsilon=0.001, decay=0.99, **kwargs):
        super(BatchNormalize, self).__init__(name=name, **kwargs)
        self._epsilon = epsilon
        self._decay = decay

    def build(self, input_shape):
        """Create the moving statistics and the scale/offset parameters.

        All four weights are 1-D with length ``input_shape[-1]``.
        """
        param_shape = [input_shape[-1]]
        # Moving statistics are updated by assignment in call(), not by the
        # optimizer, hence trainable=False.
        self._mean = self.add_weight(
            name='mean', shape=param_shape, dtype=tf.float32,
            initializer=tf.zeros_initializer(), trainable=False)
        self._variance = self.add_weight(
            name="variance", shape=param_shape, dtype=tf.float32,
            initializer=tf.ones_initializer(), trainable=False)
        self._gamma = self.add_weight(
            name='gamma', shape=param_shape, dtype=tf.float32,
            initializer=tf.ones_initializer(), trainable=True)
        self._beta = self.add_weight(
            name="beta", shape=param_shape, dtype=tf.float32,
            initializer=tf.zeros_initializer(), trainable=True)
        # Reduce over every axis but the last: [0] for rank-2 input and
        # [0, 1, 2] for rank-4 input, and correct for other ranks as well.
        self._axes = list(range(len(input_shape) - 1))
        super(BatchNormalize, self).build(input_shape)

    def call(self, inputs, training=None):
        """Normalize ``inputs``.

        Args:
            inputs: Tensor whose last axis has the size seen in ``build``.
            training: Truthy to use batch statistics and update the moving
                averages; falsy or ``None`` to use the moving averages.

        Returns:
            The result of ``tf.nn.batch_normalization`` on ``inputs``.
        """
        if not training:
            return tf.nn.batch_normalization(
                inputs, self._mean, self._variance,
                self._beta, self._gamma, self._epsilon)

        # `keep_dims` is the TF1-only spelling and is rejected by TF2; the
        # default (False) is what is wanted, so the argument is omitted.
        batch_mean, batch_variance = tf.nn.moments(
            inputs, axes=self._axes, name='moment')
        batch_weight = 1.0 - self._decay
        train_mean = self._mean.assign(
            tf.add(tf.multiply(self._mean, self._decay),
                   tf.multiply(batch_mean, batch_weight)))
        train_variance = self._variance.assign(
            tf.add(tf.multiply(self._variance, self._decay),
                   tf.multiply(batch_variance, batch_weight)))
        # Tie the moving-average updates to the output so that they run
        # whenever the normalized tensor is evaluated in graph mode.
        with tf.control_dependencies([train_mean, train_variance]):
            return tf.nn.batch_normalization(
                inputs, batch_mean, batch_variance,
                self._beta, self._gamma, self._epsilon, name="batch_normal")

    def get_config(self):
        """Return the layer config including ``epsilon`` and ``decay``."""
        config = super(BatchNormalize, self).get_config()
        config.update({'epsilon': self._epsilon, 'decay': self._decay})
        return config