热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorchcheckpoint_PyTorch|保存和加载模型

原题|SAVINGANDLOADINGMODELS作者|MatthewInkawhich原文|https:pytorch.orgtutorialsbeginnersaving_lo

原题 | SAVING AND LOADING MODELS

作者 | Matthew Inkawhich

原文 | https://pytorch.org/tutorials/beginner/saving_loading_models.html

声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出处,请勿用作商业或者非法用途

简介

本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数:

  1. torch.save :把序列化的对象保存到硬盘。它利用了 Python 的 pickle 来实现序列化。模型、张量以及字典都可以用该函数进行保存;
  2. torch.load:采用 pickle 将反序列化的对象从存储中加载进来。
  3. torch.nn.Module.load_state_dict:采用一个反序列化的 state_dict加载一个模型的参数字典。

本文主要内容如下:

  • 什么是状态字典(state_dict)?
  • 预测时加载和保存模型
  • 加载和保存一个通用的检查点(Checkpoint)
  • 在同一个文件保存多个模型
  • 采用另一个模型的参数来预热模型(Warmstaring Model)
  • 不同设备下保存和加载模型

1. 什么是状态字典(state_dict)

PyTorch 中,一个模型(torch.nn.Module)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters())中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnormrunning_mean)。优化器对象(torch.optim)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数。

由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。

下面是一个简单的使用例子,例子来自:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

# Define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize model
model = TheModelClass()# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "t", model.state_dict()[param_tensor].size())# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "t", optimizer.state_dict()[var_name])

上述代码先是简单定义一个 5 层的 CNN,然后分别打印模型的参数和优化器参数。

输出结果:

Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

2. 预测时加载和保存模型

加载/保存状态字典(推荐做法)

保存的代码:

torch.save(model.state_dict(), PATH)

加载的代码:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save() 来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。

通常会用 .pt 或者 .pth 后缀来保存模型。

记住

  1. 在进行预测之前,必须调用 model.eval() 方法来将 dropoutbatch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。
  2. load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)

加载/保存整个模型

保存:

torch.save(model, PATH)

加载:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

3. 加载和保存一个通用的检查点(Checkpoint)

保存的示例代码:

torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)

加载的示例代码:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar

加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。

加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。

4. 在同一个文件保存多个模型

保存模型的示例代码:

torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),...}, PATH)

加载模型的示例代码:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

当我们希望保存的是一个包含多个网络模型 torch.nn.Modules 的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 state_dict 和对应优化器的 state_dict 。除此之外,还可以继续保存其他相同的信息。

加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar 作为后缀名。

5. 采用另一个模型的参数来预热模型(Warmstaring Model)

保存模型的示例代码:

torch.save(modelA.state_dict(), PATH)

加载模型的示例代码:

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

在之前迁移学习教程中也介绍了可以通过预训练模型来微调,加快模型训练速度和提高模型的精度。

这种做法通常是加载预训练模型的部分网络参数作为模型的初始化参数,然后可以加快模型的收敛速度。

加载预训练模型的代码如上述所示,其中设置参数 strict=False 表示忽略不匹配的网络层参数,因为通常我们都不会完全采用和预训练模型完全一样的网络,通常输出层的参数就会不一样。

当然,如果希望加载参数名不一样的参数,可以通过修改加载的模型对应的参数名字,这样参数名字匹配了就可以成功加载。

6. 不同设备下保存和加载模型

在GPU上保存模型,在 CPU 上加载模型

保存模型的示例代码:

torch.save(model.state_dict(), PATH)

加载模型的示例代码:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load() 的时候,设置参数 map_location ,指定采用的设备是 torch.device('cpu'),这个做法会将张量都重新映射到 CPU 上。

在GPU上保存模型,在 GPU 上加载模型

保存模型的示例代码:

torch.save(model.state_dict(), PATH)

加载模型的示例代码:

device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

在 GPU 上训练和加载模型,调用 torch.load() 加载模型后,还需要采用 model.to(torch.device('cuda')),将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device)

在CPU上保存,在GPU上加载模型

保存模型的示例代码:

torch.save(model.state_dict(), PATH)

加载模型的示例代码:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

这次是 CPU 上训练模型,但在 GPU 上加载模型使用,那么就需要通过参数 map_location 指定设备。然后继续记得调用 model.to(torch.device('cuda'))

保存 torch.nn.DataParallel 模型

保存模型的示例代码:

torch.save(model.module.state_dict(), PATH)

torch.nn.DataParallel 是用于实现多 GPU 并行的操作,保存模型的时候,是采用 model.module.state_dict()

加载模型的代码也是一样的,采用 torch.load() ,并可以放到指定的 GPU 显卡上。


完整的代码:

https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py



推荐阅读
  • 本文介绍了数据库的存储结构及其重要性,强调了关系数据库范例中将逻辑存储与物理存储分开的必要性。通过逻辑结构和物理结构的分离,可以实现对物理存储的重新组织和数据库的迁移,而应用程序不会察觉到任何更改。文章还展示了Oracle数据库的逻辑结构和物理结构,并介绍了表空间的概念和作用。 ... [详细]
  • 本文介绍了Android 7的学习笔记总结,包括最新的移动架构视频、大厂安卓面试真题和项目实战源码讲义。同时还分享了开源的完整内容,并提醒读者在使用FileProvider适配时要注意不同模块的AndroidManfiest.xml中配置的xml文件名必须不同,否则会出现问题。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 海马s5近光灯能否直接更换为H7?
    本文主要介绍了海马s5车型的近光灯是否可以直接更换为H7灯泡,并提供了完整的教程下载地址。此外,还详细讲解了DSP功能函数中的数据拷贝、数据填充和浮点数转换为定点数的相关内容。 ... [详细]
  • 本文分析了Wince程序内存和存储内存的分布及作用。Wince内存包括系统内存、对象存储和程序内存,其中系统内存占用了一部分SDRAM,而剩下的30M为程序内存和存储内存。对象存储是嵌入式wince操作系统中的一个新概念,常用于消费电子设备中。此外,文章还介绍了主电源和后备电池在操作系统中的作用。 ... [详细]
  • 本文概述了JNI的原理以及常用方法。JNI提供了一种Java字节码调用C/C++的解决方案,但引用类型不能直接在Native层使用,需要进行类型转化。多维数组(包括二维数组)都是引用类型,需要使用jobjectArray类型来存取其值。此外,由于Java支持函数重载,根据函数名无法找到对应的JNI函数,因此介绍了JNI函数签名信息的解决方案。 ... [详细]
  • Mono为何能跨平台
    概念JIT编译(JITcompilation),运行时需要代码时,将Microsoft中间语言(MSIL)转换为机器码的编译。CLR(CommonLa ... [详细]
  • 原文地址http://balau82.wordpress.com/2010/02/28/hello-world-for-bare-metal-arm-using-qemu/最开始时 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文介绍了九度OnlineJudge中的1002题目“Grading”的解决方法。该题目要求设计一个公平的评分过程,将每个考题分配给3个独立的专家,如果他们的评分不一致,则需要请一位裁判做出最终决定。文章详细描述了评分规则,并给出了解决该问题的程序。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • HDFS2.x新特性
    一、集群间数据拷贝scp实现两个远程主机之间的文件复制scp-rhello.txtroothadoop103:useratguiguhello.txt推pushscp-rr ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • x86 linux的进程调度,x86体系结构下Linux2.6.26的进程调度和切换
    进程调度相关数据结构task_structtask_struct是进程在内核中对应的数据结构,它标识了进程的状态等各项信息。其中有一项thread_struct结构的 ... [详细]
  • ihaveusedthedelphidatabindingwizardwithmyxmlfile,andeverythingcompilesandrunsfine. ... [详细]
author-avatar
____L振豪
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有