1> cleros..:
这是一个很好的问题,您已经给出了不错的答案。但是,它重塑了轮子-一个非常优雅的Pytorch内部例程,使您无需花费太多精力即可完成此操作-并且适用于任何网络。
这里的核心概念是PyTorch的state_dict
。状态字典有效地包含parameters
由nn.Modules
,及其子模块等的关系给出的树结构组织的结构。
简短的答案
如果只希望代码使用来将值加载到张量中state_dict
,请尝试以下行(其中dict
包含有效state_dict
):
`model.load_state_dict(dict, strict=False)`
strict=False
如果只想加载一些参数值,在哪里至关重要。
长答案-包括对PyTorch的介绍 state_dict
这是一个状态字典如何查找GRU的示例(我选择input_size = hidden_size = 2
以便可以打印整个状态字典):
rnn = torch.nn.GRU(2, 2, 1)
rnn.state_dict()
# Out[10]:
# OrderedDict([('weight_ih_l0', tensor([[-0.0023, -0.0460],
# [ 0.3373, 0.0070],
# [ 0.0745, -0.5345],
# [ 0.5347, -0.2373],
# [-0.2217, -0.2824],
# [-0.2983, 0.4771]])),
# ('weight_hh_l0', tensor([[-0.2837, -0.0571],
# [-0.1820, 0.6963],
# [ 0.4978, -0.6342],
# [ 0.0366, 0.2156],
# [ 0.5009, 0.4382],
# [-0.7012, -0.5157]])),
# ('bias_ih_l0',
# tensor([-0.2158, -0.6643, -0.3505, -0.0959, -0.5332, -0.6209])),
# ('bias_hh_l0',
# tensor([-0.1845, 0.4075, -0.1721, -0.4893, -0.2427, 0.3973]))])
因此state_dict
网络的所有参数。如果我们有“嵌套” nn.Modules
,我们将得到由参数名称表示的树:
class MLP(torch.nn.Module):
def __init__(self):
torch.nn.Module.__init__(self)
self.lin_a = torch.nn.Linear(2, 2)
self.lin_b = torch.nn.Linear(2, 2)
mlp = MLP()
mlp.state_dict()
# Out[23]:
# OrderedDict([('lin_a.weight', tensor([[-0.2914, 0.0791],
# [-0.1167, 0.6591]])),
# ('lin_a.bias', tensor([-0.2745, -0.1614])),
# ('lin_b.weight', tensor([[-0.4634, -0.2649],
# [ 0.4552, 0.3812]])),
# ('lin_b.bias', tensor([ 0.0273, -0.1283]))])
class NestedMLP(torch.nn.Module):
def __init__(self):
torch.nn.Module.__init__(self)
self.mlp_a = MLP()
self.mlp_b = MLP()
n_mlp = NestedMLP()
n_mlp.state_dict()
# Out[26]:
# OrderedDict([('mlp_a.lin_a.weight', tensor([[ 0.2543, 0.3412],
# [-0.1984, -0.3235]])),
# ('mlp_a.lin_a.bias', tensor([ 0.2480, -0.0631])),
# ('mlp_a.lin_b.weight', tensor([[-0.4575, -0.6072],
# [-0.0100, 0.5887]])),
# ('mlp_a.lin_b.bias', tensor([-0.3116, 0.5603])),
# ('mlp_b.lin_a.weight', tensor([[ 0.3722, 0.6940],
# [-0.5120, 0.5414]])),
# ('mlp_b.lin_a.bias', tensor([0.3604, 0.0316])),
# ('mlp_b.lin_b.weight', tensor([[-0.5571, 0.0830],
# [ 0.5230, -0.1020]])),
# ('mlp_b.lin_b.bias', tensor([ 0.2156, -0.2930]))])
那么-如果您不想提取状态dict而是更改它,从而更改网络参数怎么办?使用nn.Module.load_state_dict(state_dict, strict=True)
(链接到文档)此方法允许您将具有任意值的整个state_dict加载到相同类型的实例化模型中,只要键(即参数名称)正确且值(即参数)torch.tensors
为正确的形状。如果将strict
kwarg设置为True
(默认值),则加载的dict必须与原始状态dict完全匹配,但参数值除外。也就是说,每个参数必须有一个新值。
对于上面的GRU示例,我们需要每个的正确大小的张量(以及正确的设备btw)'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'
。由于有时我们只想加载一些值(就像我想的那样),我们可以将strict
kwarg 设置为False
-,然后仅加载部分状态dict,例如仅包含的参数值的dict 'weight_ih_l0'
。
作为实用建议,我将简单地创建要向其中加载值的模型,然后打印状态字典(或至少打印键列表和各自的张量大小)
print([k, v.shape for k, v in model.state_dict().items()])
这告诉您要更改参数的确切名称。然后,您只需使用相应的参数名称和张量创建状态dict并加载它:
from dollections import OrderedDict
new_state_dict = OrderedDict({'tensor_name_retrieved_from_original_dict': new_tensor_value})
model.load_state_dict(new_state_dict, strict=False)