前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。
所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。
batch normalization 的实现非常简单,对于给定的一个 batch 的数据:则其公式为:
也就是说,BN是针对输入的整个数据来说的。
在对普通数值型数据进行BN时,由于其输入为(batsize, 特征数)。所以求均值就是对一个batch size中的所有数据进行均值计算,得到每一个特征的均值,标准差也是一样。
def simple_batch_norm_1d(x, gamma, beta):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替
def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
print(x.shape)
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
if is_training:
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
# 这里使用滑动平均
moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
else:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
对于二维卷积的输出,BN是计算每一个通道的平均值。这里用mxnet的代码展示这个过程:
from mxnet import nd
def pure_batch_norm(X, gamma, beta, eps=1e-5):
assert len(X.shape) in (2, 4)
# 全连接: batch_size x feature
if len(X.shape) == 2:
# 每个输入维度在样本上的平均和方差
mean = X.mean(axis=0)
variance = ((X - mean)**2).mean(axis=0)
# 2D卷积: batch_size x channel x height x width
else:
# 对每个通道算均值和方差,需要保持4D形状使得可以正确地广播
mean = X.mean(axis=(0,2,3), keepdims=True)
print(mean)
variance = ((X - mean)**2).mean(axis=(0,2,3), keepdims=True)
# 均一化
X_hat = (X - mean) / nd.sqrt(variance + eps)
# 拉升和偏移
return gamma.reshape(mean.shape) * X_hat + beta.reshape(mean.shape)