参数在训练过程中被更改,也就是说,它们是在神经网络训练过程中学习到的东西,但是什么是缓冲区?
在神经网络训练中是学到的吗?
Pytorch doc用于register_buffer()
方法读取
通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm
running_mean
不是参数,而是持久状态的一部分。
如您所见,在训练过程中使用SGD学习并更新了模型参数。
但是,有时还有其他数量属于模型“状态”的一部分,应
另存为state_dict
。
-移至模型的其余参数cuda()
或cpu()
与之一起使用。
-转换成float
/ half
/ double
与模型的参数的其余部分。
将这些“参数”注册为模型buffer
可以使pytorch跟踪它们并像常规参数一样保存它们,但是可以防止pytorch使用SGD机制更新它们。
用于缓冲的一个例子中可以找到_BatchNorm
模块,其中running_mean
,running_var
并num_batches_tracked
通过累积通过所述层转发的数据的统计信息被登记为缓冲器和更新。这与使用常规SGD优化学习数据的仿射变换的参数weight
和bias
参数相反。