参考链接:https://www.cnblogs.com/leebxo/p/10880399.html
BatchNorm2d
中的track_running_stats
参数
- 如果
BatchNorm2d
的参数val,track_running_stats
设置False
,那么加载预训练后每次模型测试测试集的结果时都不一样;track_running_stats
设置为True
时,每次得到的结果都一样。
running_mean
和running_var
参数
-
running_mean
和running_var
参数是根据输入的batch
的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。
torch.nn.BatchNorm1d(num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)
BatchNorm2d
参数讲解
- 一般来说pytorch中的模型都是继承
nn.Module
类的,都有一个属性trainning
指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN
层或者Dropout
层。通常用model.train()
指定当前模型model
为训练状态,model.eval()
指定当前模型为测试状态。
- 同时,
BN
的API中有几个参数需要比较关心的,一个是affine
指定是否需要仿射,还有个是track_running_stats
指定是否跟踪当前batch
的统计特性。容易出现问题也正好是这三个参数:trainning
,affine
,track_running_stats
。
- 其中的
affine
指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False
则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0
,并且不能学习被更新。一般都会设置成affine=True
。
-
trainning
和track_running_stats
,track_running_stats=True
表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch
的统计特性。相反的,如果track_running_stats=False
那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False
,此时如果batch_size
比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。