热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch模型训练中实现CPU与GPU的高效切换方法

1.如何进行迁移使用Pytorch写的模型:对模型和相应的数据使用.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。

1.如何进行迁移

使用Pytorch写的模型:


  • 对模型和相应的数据使用.cuda()处理。通过这种方式,我们就可以将内存中的数据复制到GPU的显存中去。从而可以通过GPU来进行运算了。
  • 另外一种方式,使用.to(device)的方式,将cpu的数据切换到gpu,如下:

#配置参数:config.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(config.device)

2.对数据的迁移

.cuda() 操作默认使用GPU 0也就是第一张显卡来进行操作。当我们想要存储在其他显卡中时可以使用 .cuda(<显卡号数>) 来将数据存储在指定的显卡中。还有很多种方式&#xff0c;具体参考官方文档。

对于不同存储位置的变量&#xff0c;我们是不可以对他们直接进行计算的。存储在不同位置中的数据是不可以直接进行交互计算的。

换句话说也就是上面例子中的 torch.FloatTensor 是不可以直接与 torch.cuda.FloatTensor 进行基本运算的。位于不同GPU显存上的数据也是不能直接进行计算的。

对于Variable&#xff0c;其实就仅仅是一种能够记录操作信息并且能够自动求导的容器&#xff0c;实际上的关键信息并不在Variable本身&#xff0c;而更应该侧重于Variable中存储的data。

这里举一个例子&#xff0c;训练的时候&#xff0c;怎么在epoch中&#xff0c;将数据从cpu转到gpu&#xff1a;

for epoch in range(config.num_epochs):print(&#39;Epoch [{}/{}]&#39;.format(epoch &#43; 1, config.num_epochs))total_eval_accuracy &#61; 0total_loss &#61; 0for step, batch in enumerate(train_dataloader):#重点的两句话&#xff0c;batch[0]是训练数据&#xff0c;batch[1]是训练数据的labelbatch[0] &#61; torch.LongTensor(batch[0]).to(config.device)batch[1] &#61; torch.LongTensor(batch[1]).to(config.device)

3.模型迁移

一行代码&#xff1a;

#config是配置文件&#xff0c;里面包含了设备信息&#xff0c;模型参数等&#xff0c;大致理解意思就好&#xff0c;不要在乎config里面具体是什么。
model &#61; Classifier.nn(config.para)
model &#61; model.to(config.device)

4.汇总

在代码中使用GPU训练主要有三处需要注意&#xff1a;模型转为cuda&#xff0c;数据转为cuda&#xff0c;和输出数据去cuda&#xff0c;转为numpy。修改的地方包括将数据的形式变成 GPU 能读的形式, 然后将 网络模型 也变成 GPU 能读的形式。

模型训练时:如果数据放在了GPU上&#xff0c;那么模型也要转到GPU上。

模型预测时:计算预测的acc、auc这类型的评估参数时&#xff0c;实在cpu上进行的&#xff0c;所以模型evaluate时&#xff0c;需要将loss之类的转到cpu上&#xff0c;例子如下:

labels &#61; labels.data.cpu().numpy()predic &#61; torch.max(logits, 1)[1].cpu().numpy()labels_all &#61; np.append(labels_all, labels)predict_all &#61; np.append(predict_all, predic)acc &#61; metrics.accuracy_score(labels_all, predict_all)

转换时常见错误


1.RuntimeError: Input, output and indices must be on the current device

如果你的数据和模型没有同时在gpu或者cpu上&#xff0c;训练模型时&#xff0c;会报错如下&#xff0c;意思是输入和输出需要在同一设备上。

RuntimeError: Input, output and indices must be on the current device

解决方法&#xff1a;

将数据和模型放在同一设备即可。


2.AttributeError: &#39;list&#39; object has no attribute &#39;cuda&#39;

没搞清楚数据是不是tensor&#xff0c;就转到gpu时&#xff0c;会报出这个错误。

解决方法&#xff1a;先转为tensor&#xff0c;再转到gpu。

例子&#xff1a;

查看自己的target类型&#xff0c;原为[&#39;1&#39;,&#39;0&#39;,&#39;1&#39;,&#39;1&#39;]。这种列表试字符串型。而应该修改为torch.tensor类型。才能用于网络计算

简单改为&#xff1a;先改为numpy再转换为tensor&#xff0c;搞定&#xff01;

label &#61; torch.from_numpy(np.fromstring(label, dtype&#61;int, sep&#61;&#39;,&#39;))

 

参考&#xff1a;
1.https://zhuanlan.zhihu.com/p/31936740

2.https://blog.csdn.net/qq_21578849/article/details/85240797

3.训练demo&#xff1a;https://blog.csdn.net/WeDon_t/article/details/104300877?utm_medium&#61;distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-3.channel_param&depth_1-utm_source&#61;distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-3.channel_param

4.训练常见错误&#xff1a;https://blog.csdn.net/u014264373/article/details/87640753

 


推荐阅读
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 扫描线三巨头 hdu1928hdu 1255  hdu 1542 [POJ 1151]
    学习链接:http:blog.csdn.netlwt36articledetails48908031学习扫描线主要学习的是一种扫描的思想,后期可以求解很 ... [详细]
  • 本文详细介绍了Java中org.w3c.dom.Text类的splitText()方法,通过多个代码示例展示了其实际应用。该方法用于将文本节点在指定位置拆分为两个节点,并保持在文档树中。 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • dotnet 通过 Elmish.WPF 使用 F# 编写 WPF 应用
    本文来安利大家一个有趣而且强大的库,通过F#和C#混合编程编写WPF应用,可以在WPF中使用到F#强大的数据处理能力在GitHub上完全开源Elmis ... [详细]
  • 本文详细介绍了Java中org.neo4j.helpers.collection.Iterators.single()方法的功能、使用场景及代码示例,帮助开发者更好地理解和应用该方法。 ... [详细]
  • 本文详细介绍了如何使用 Yii2 的 GridView 组件在列表页面实现数据的直接编辑功能。通过具体的代码示例和步骤,帮助开发者快速掌握这一实用技巧。 ... [详细]
  • 本文详细介绍了如何在 Spring Boot 应用中通过 @PropertySource 注解读取非默认配置文件,包括配置文件的创建、映射类的设计以及确保 Spring 容器能够正确加载这些配置的方法。 ... [详细]
  • 本文介绍如何使用 NSTimer 实现倒计时功能,详细讲解了初始化方法、参数配置以及具体实现步骤。通过示例代码展示如何创建和管理定时器,确保在指定时间间隔内执行特定任务。 ... [详细]
  • 从 .NET 转 Java 的自学之路:IO 流基础篇
    本文详细介绍了 Java 中的 IO 流,包括字节流和字符流的基本概念及其操作方式。探讨了如何处理不同类型的文件数据,并结合编码机制确保字符数据的正确读写。同时,文中还涵盖了装饰设计模式的应用,以及多种常见的 IO 操作实例。 ... [详细]
  • 本文介绍了如何通过 Maven 依赖引入 SQLiteJDBC 和 HikariCP 包,从而在 Java 应用中高效地连接和操作 SQLite 数据库。文章提供了详细的代码示例,并解释了每个步骤的实现细节。 ... [详细]
  • ASP.NET MVC中Area机制的实现与优化
    本文探讨了在ASP.NET MVC框架中,如何通过Area机制有效地组织和管理大规模应用程序的不同功能模块。通过合理的文件夹结构和命名规则,开发人员可以更高效地管理和扩展项目。 ... [详细]
  • 基于KVM的SRIOV直通配置及性能测试
    SRIOV介绍、VF直通配置,以及包转发率性能测试小慢哥的原创文章,欢迎转载目录?1.SRIOV介绍?2.环境说明?3.开启SRIOV?4.生成VF?5.VF ... [详细]
  • 本文提供了一系列Python编程基础练习题,涵盖了列表操作、循环结构、字符串处理和元组特性等内容。通过这些练习题,读者可以巩固对Python语言的理解并提升编程技能。 ... [详细]
  • 深入理解Redis的数据结构与对象系统
    本文详细探讨了Redis中的数据结构和对象系统的实现,包括字符串、列表、集合、哈希表和有序集合等五种核心对象类型,以及它们所使用的底层数据结构。通过分析源码和相关文献,帮助读者更好地理解Redis的设计原理。 ... [详细]
author-avatar
晓玲建雯东佳
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有