热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch自由载入部分模型参数并冻结

Pytorch的load方法和load_state_dict方法只能较为固定的读入参数文件,他们要求读入的state_dict的key和Model.state_dict()的key

Pytorch的load方法和load_state_dict方法只能较为固定的读入参数文件,他们要求读入的state_dict的key和Model.state_dict()的key对应相等。

而我们在进行迁移学习的过程中也许只需要使用某个预训练网络的一部分,把多个网络拼和成一个网络,或者为了得到中间层的输出而分离预训练模型中的Sequential 等等,这些情况下。传统的load方法就不是很有效了。

例如,我们想利用Mobilenet的前7个卷积并把这几层冻结,后面的部分接别的结构,或者改写成FCN结构,传统的方法就不奏效了。

最普适的方法是:构建一个字典,使得字典的keys和我们自己创建的网络相同,我们再从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,目前只能想到这个方法应对较为复杂的网络变换。

网上查“载入部分模型”,“冻结部分模型”一般都是只改个FC,根本没有用,初学的时候自己写state_dict也踩了一些坑,发出来记录一下。

一.载入部分预训练参数

我们先看看Mobilenet的结构

( 来源github,附带预训练模型mobilenet_sgd_rmsprop_69.526.tar)

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)
self.model = nn.Sequential(
conv_bn( 3, 32, 2),
conv_dw( 32, 64, 1),
conv_dw( 64, 128, 2),
conv_dw(128, 128, 1),
conv_dw(128, 256, 2),
conv_dw(256, 256, 1),
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 1024, 2),
conv_dw(1024, 1024, 1),
nn.AvgPool2d(7),
)
self.fc = nn.Linear(1024, 1000)
def forward(self, x):
x = self.model(x)
x = x.view(-1, 1024)
x = self.fc(x)
return x

我们只需要前7层卷积,并且为了方便日后concate操作,我们把Sequential拆开,成为下面的样子

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True)
)
def conv_dw(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU(inplace=True), nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
)

self.conv1 = conv_bn( 3, 32, 2)
self.conv2 = conv_dw( 32, 64, 1)
self.conv3 = conv_dw( 64, 128, 2)
self.conv4 = conv_dw(128, 128, 1)
self.conv5 = conv_dw(128, 256, 2)
self.conv6 = conv_dw(256, 256, 1)
self.conv7 = conv_dw(256, 512, 2)

# 原来这些不要了
# 可以自己接后面的结构
''' self.features = nn.Sequential( conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 512, 1), conv_dw(512, 1024, 2), conv_dw(1024, 1024, 1), nn.AvgPool2d(7),) self.fc = nn.Linear(1024, 1000) '''

def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x4 = self.conv4(x3)
x5 = self.conv5(x4)
x6 = self.conv6(x5)
x7 = self.conv7(x6)
#x8 = self.features(x7)
#out = self.fc
return (x1,x2,x3,x4,x4,x6,x7)

我们更具改过的结构创建一个net,看看他的state_dict和我们预训练文件的state_dict有啥区别

net = Net()
#我的电脑没有GPU,他的参数是GPU训练的cudatensor,于是要下面这样转换一下
dict_trained = torch.load("mobilenet_sgd_rmsprop_69.526.tar",map_location=lambda storage, loc: storage)["state_dict"]
dict_new = net.state_dict().copy()
new_list = list (net.state_dict().keys() )
trained_list = list (dict_trained.keys() )
print("new_state_dict size: {} trained state_dict size: {}".format(len(new_list),len(trained_list)) )
print("New state_dict first 10th parameters names")
print(new_list[:10])
print("trained state_dict first 10th parameters names")
print(trained_list[:10])
print(type(dict_new))
print(type(dict_trained))

得到输出如下:

我们截断一半之后,参数由137变成65了,前十个参数看出,名字变了但是顺序其实没变。state_dict的数据类型是Odict,可以按照dict的操作方法操作。

new_state_dict size: 65 trained state_dict size: 137

New state_dict first 10th parameters names

[‘conv1.0.weight’, ‘conv1.1.weight’, ‘conv1.1.bias’, ‘conv1.1.running_mean’, ‘conv1.1.running_var’, ‘conv2.0.weight’, ‘conv2.1.weight’, ‘conv2.1.bias’, ‘conv2.1.running_mean’, ‘conv2.1.running_var’]

trained state_dict first 10th parameters names

[‘module.model.0.0.weight’, ‘module.model.0.1.weight’, ‘module.model.0.1.bias’, ‘module.model.0.1.running_mean’, ‘module.model.0.1.running_var’, ‘module.model.1.0.weight’, ‘module.model.1.1.weight’, ‘module.model.1.1.bias’, ‘module.model.1.1.running_mean’, ‘module.model.1.1.running_var’]

我们看出只要构建一个字典,使得字典的keys和我们自己创建的网络相同,我们在从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,这是最普适的方法适用于所有的网络变化。

for i in range(65):
dict_new[ new_list[i] ] = dict_trained[ trained_list[i] ]
net.load_state_dict(dict_new)

还有别的情况,比如我们只是在后面加了一些层,没有改变原来网络层的名字和结构,可以用下面的简便方法:

loaded_dict = {k: loaded_dict[k] for k, _ in model.state_dict()}

二.冻结这几层参数

方法很多,这里用和上面方法对应的冻结方法

发现之前的冻结有问题,还是建议看一下
https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088
或者
https://discuss.pytorch.org/t/correct-way-to-freeze-layers/26714
或者

对应的,在训练时候,optimizer里面只能更新requires_grad = True的参数,于是

optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, net.parameters(),lr) )


推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 本文介绍了PhysioNet网站提供的生理信号处理工具箱WFDB Toolbox for Matlab的安装和使用方法。通过下载并添加到Matlab路径中或直接在Matlab中输入相关内容,即可完成安装。该工具箱提供了一系列函数,可以方便地处理生理信号数据。详细的安装和使用方法可以参考本文内容。 ... [详细]
  • 图解redis的持久化存储机制RDB和AOF的原理和优缺点
    本文通过图解的方式介绍了redis的持久化存储机制RDB和AOF的原理和优缺点。RDB是将redis内存中的数据保存为快照文件,恢复速度较快但不支持拉链式快照。AOF是将操作日志保存到磁盘,实时存储数据但恢复速度较慢。文章详细分析了两种机制的优缺点,帮助读者更好地理解redis的持久化存储策略。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • Go GUIlxn/walk 学习3.菜单栏和工具栏的具体实现
    本文介绍了使用Go语言的GUI库lxn/walk实现菜单栏和工具栏的具体方法,包括消息窗口的产生、文件放置动作响应和提示框的应用。部分代码来自上一篇博客和lxn/walk官方示例。文章提供了学习GUI开发的实际案例和代码示例。 ... [详细]
  • 海马s5近光灯能否直接更换为H7?
    本文主要介绍了海马s5车型的近光灯是否可以直接更换为H7灯泡,并提供了完整的教程下载地址。此外,还详细讲解了DSP功能函数中的数据拷贝、数据填充和浮点数转换为定点数的相关内容。 ... [详细]
  • 手把手教你使用GraphPad Prism和Excel绘制回归分析结果的森林图
    本文介绍了使用GraphPad Prism和Excel绘制回归分析结果的森林图的方法。通过展示森林图,可以更加直观地将回归分析结果可视化。GraphPad Prism是一款专门为医学专业人士设计的绘图软件,同时也兼顾统计分析的功能,操作便捷,可以帮助科研人员轻松绘制出高质量的专业图形。文章以一篇发表在JACC杂志上的研究为例,利用其中的多因素回归分析结果来绘制森林图。通过本文的指导,读者可以学会如何使用GraphPad Prism和Excel绘制回归分析结果的森林图。 ... [详细]
  • 大数据Hadoop生态(20)MapReduce框架原理OutputFormat的开发笔记
    本文介绍了大数据Hadoop生态(20)MapReduce框架原理OutputFormat的开发笔记,包括outputFormat接口实现类、自定义outputFormat步骤和案例。案例中将包含nty的日志输出到nty.log文件,其他日志输出到other.log文件。同时提供了一些相关网址供参考。 ... [详细]
  • 如何使用Python从工程图图像中提取底部的方法?
    本文介绍了使用Python从工程图图像中提取底部的方法。首先将输入图片转换为灰度图像,并进行高斯模糊和阈值处理。然后通过填充潜在的轮廓以及使用轮廓逼近和矩形核进行过滤,去除非矩形轮廓。最后通过查找轮廓并使用轮廓近似、宽高比和轮廓区域进行过滤,隔离所需的底部轮廓,并使用Numpy切片提取底部模板部分。 ... [详细]
  • 抽空写了一个ICON图标的转换程序
    抽空写了一个ICON图标的转换程序,支持png\jpe\bmp格式到ico的转换。具体的程序就在下面,如果看的人多,过两天再把思路写一下。 ... [详细]
  • VScode格式化文档换行或不换行的设置方法
    本文介绍了在VScode中设置格式化文档换行或不换行的方法,包括使用插件和修改settings.json文件的内容。详细步骤为:找到settings.json文件,将其中的代码替换为指定的代码。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • Android系统移植与调试之如何修改Android设备状态条上音量加减键在横竖屏切换的时候的显示于隐藏
    本文介绍了如何修改Android设备状态条上音量加减键在横竖屏切换时的显示与隐藏。通过修改系统文件system_bar.xml实现了该功能,并分享了解决思路和经验。 ... [详细]
  • 本文介绍了深入浅出Linux设备驱动编程的重要性,以及两种加载和删除Linux内核模块的方法。通过一个内核模块的例子,展示了模块的编译和加载过程,并讨论了模块对内核大小的控制。深入理解Linux设备驱动编程对于开发者来说非常重要。 ... [详细]
  • Vagrant虚拟化工具的安装和使用教程
    本文介绍了Vagrant虚拟化工具的安装和使用教程。首先介绍了安装virtualBox和Vagrant的步骤。然后详细说明了Vagrant的安装和使用方法,包括如何检查安装是否成功。最后介绍了下载虚拟机镜像的步骤,以及Vagrant镜像网站的相关信息。 ... [详细]
author-avatar
半暖半夏半流年
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有