python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > Pytorch BatchNorm2d参数

Pytorch中关于BatchNorm2d的参数解释

作者:顾北向南

这篇文章主要介绍了Pytorch中关于BatchNorm2d的参数解释,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

BatchNorm2d中的track_running_stats参数

如果BatchNorm2d的参数val,track_running_stats设置False,那么加载预训练后每次模型测试测试集的结果时都不一样;

track_running_stats设置为True时,每次得到的结果都一样。

running_mean和running_var参数

running_mean和running_var参数是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。

torch.nn.BatchNorm1d(num_features,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True)

BatchNorm2d参数讲解

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。

同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0 \gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。

trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文