由于实验需要,便用pytorch函数手动实现了batchnorm函数,但是最后发现结果不对,最后在Pytorch论坛上找到了相关解决办法!
基础
前期实现
上述博客给出了python实现代码,我将其中的numpy函数改成了pytorch的相关函数:
def fowardbn(x, gam, beta, ):
'''
x:(N,D)维数据
'''
momentum = 0.1
eps = 1e-05
running_mean = 0
running_var = 1
running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
mean = x.mean(dim=0)
var = x.var(dim=0)
# bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
x_hat = (x - mean) / torch.sqrt(var + eps)
out = gam * x_hat + beta
cache = (x, gam, beta, x_hat, mean, var, eps)
return out, cache
然后与nn.BatchNorm1d计算的结果比较:
model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)
out, cache = fowardbn(input2, model2.weight, model2.bias) # 使用相同的尺度变换量
发现结果x和out的值不一样。
然后就不停的找问题是不是实现方法有差别。
在官方论坛上找到了,有人遇到了相同的问题,官方人员给了答复,还提供了一个官方的实现版本。
Pytorch的论坛做的还是挺不错的。
问题
我发现官方实现的代码中
var = input.var([0, 2, 3], unbiased=False)
在求输入的方差时,多了一个参数设置unbiased=False
,不懂。
我又查看了一下Pytorch的代码文档:
torch.var(input, unbiased=True) → Tensor
Returns the variance of all elements in the
input
tensor.
Ifunbiased
isFalse
, then the variance will be calculated via the biased estimator. Otherwise, Bessel’s correction will be used.
意思是unbiased = False
时,通过无偏估计计算,反之则通过贝塞尔矫正方法计算。可用如下图片总结:
这是统计方面的知识了,可以参考此博客。
最终实现代码
将初始代码中方差计算加上参数unbiased = False
,结果正确,完整代码如下
def fowardbn(x, gam, beta, ):
momentum = 0.1
eps = 1e-05
running_mean = 0
running_var = 1
running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
running_var = (1 - momentum) * running_var + momentum * x.var(dim=0)
mean = x.mean(dim=0)
var = x.var(dim=0,unbiased=False)
# bnmiddle_buffer = (input - mean) / ((var + eps) ** 0.5).data
x_hat = (x - mean) / torch.sqrt(var + eps)
out = gam * x_hat + beta
cache = (x, gam, beta, x_hat, mean, var, eps)
return out, cache
model2 = nn.BatchNorm1d(5)
input1 = torch.randn(3, 5, requires_grad=True)
input2 = input1.clone().detach().requires_grad_()
x = model2(input1)
out, cache = fowardbn(input2, model2.weight, model2.bias)
Reference
Batch Normalization 学习笔记
Batch Normalization梯度反向传播推导
PyTorch论坛问题
官方人员给的batchnorm2d的手动实现代码
方差的贝塞尔校正