热门标签 | HotTags
当前位置:  开发笔记 > 前端 > 正文

解决Pytorch半精度浮点型网络训练的问题

这篇文章主要介绍了解决Pytorch半精度浮点型网络训练的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:

1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()

2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可

3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。

另外,SGD算法对于半精度和全精度计算均没有问题。

还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。

对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。

但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。

将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。

补充:pytorch半精度,混合精度,单精度训练的区别amp.initialize

看代码吧~

mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except:
    mixed_precision = False  # not installed

 model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=1)

为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。

文档地址是:https://nvidia.github.io/apex/index.html

该 工具 提供了三个功能,amp、parallel和normalization。由于目前该工具还是0.1版本,功能还是很基础的,在最后一个normalization功能中只提供了LayerNorm层的复现,实际上在后续的使用过程中会发现,出现问题最多的是pytorch的BN层。

第二个工具是pytorch的分布式训练的复现,在文档中描述的是和pytorch中的实现等价,在代码中可以选择任意一个使用,实际使用过程中发现,在使用混合精度训练时,使用Apex复现的parallel工具,能避免一些bug。

默认训练方式是 单精度float32

import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

半精度 model(img.half())

import torch
model = torch.nn.Linear(D_in, D_out).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img.half())
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

接下来是混合精度的实现,这里主要用到Apex的amp工具。

代码修改为:

加上这一句封装,

model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
import torch
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 # loss.backward()
 with amp.scale_loss(loss, optimizer) as scaled_loss:
     scaled_loss.backward()

 optimizer.step()
 optimizer.zero_grad()

实际流程为:调用amp.initialize按照预定的opt_level对model和optimizer进行设置。在计算loss时使用amp.scale_loss进行回传。

需要注意以下几点:

在调用amp.initialize之前,模型需要放在GPU上,也就是需要调用cuda()或者to()。

在调用amp.initialize之前,模型不能调用任何分布式设置函数。

此时输入数据不需要在转换为半精度。

在使用混合精度进行计算时,最关键的参数是opt_level。他一共含有四种设置值:‘00',‘01',‘02',‘03'。实际上整个amp.initialize的输入参数很多:

但是在实际使用过程中发现,设置opt_level即可,这也是文档中例子的使用方法,甚至在不同的opt_level设置条件下,其他的参数会变成无效。(已知BUG:使用‘01'时设置keep_batchnorm_fp32的值会报错)

概括起来:

00相当于原始的单精度训练。01在大部分计算时采用半精度,但是所有的模型参数依然保持单精度,对于少数单精度较好的计算(如softmax)依然保持单精度。02相比于01,将模型参数也变为半精度。

03基本等于最开始实验的全半精度的运算。值得一提的是,不论在优化过程中,模型是否采用半精度,保存下来的模型均为单精度模型,能够保证模型在其他应用中的正常使用。这也是Apex的一大卖点。

在Pytorch中,BN层分为train和eval两种操作。

实现时若为单精度网络,会调用CUDNN进行计算加速。常规训练过程中BN层会被设为train。Apex优化了这种情况,通过设置keep_batchnorm_fp32参数,能够保证此时BN层使用CUDNN进行计算,达到最好的计算速度。

但是在一些fine tunning场景下,BN层会被设为eval(我的模型就是这种情况)。此时keep_batchnorm_fp32的设置并不起作用,训练会产生数据类型不正确的bug。此时需要人为的将所有BN层设置为半精度,这样将不能使用CUDNN加速。

一个设置的参考代码如下:

def fix_bn(m):
 classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
     m.eval().half()

model.apply(fix_bn)

实际测试下来,最后的模型准确度上感觉差别不大,可能有轻微下降;时间上变化不大,这可能会因不同的模型有差别;显存开销上确实有很大的降低。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 【前端开发】深入探讨 RequireJS 与性能优化策略
    随着前端技术的迅速发展,RequireJS虽然不再像以往那样吸引关注,但其在模块化加载方面的优势仍然值得深入探讨。本文将详细介绍RequireJS的基本概念及其作为模块加载工具的核心功能,并重点分析其性能优化策略,帮助开发者更好地理解和应用这一工具,提升前端项目的加载速度和整体性能。 ... [详细]
  • HBase在金融大数据迁移中的应用与挑战
    随着最后一台设备的下线,标志着超过10PB的HBase数据迁移项目顺利完成。目前,新的集群已在新机房稳定运行超过两个月,监控数据显示,新集群的查询响应时间显著降低,系统稳定性大幅提升。此外,数据消费的波动也变得更加平滑,整体性能得到了显著优化。 ... [详细]
  • 2019年后蚂蚁集团与拼多多面试经验详述与深度剖析
    2019年后蚂蚁集团与拼多多面试经验详述与深度剖析 ... [详细]
  • 2019年斯坦福大学CS224n课程笔记:深度学习在自然语言处理中的应用——Word2Vec与GloVe模型解析
    本文详细解析了2019年斯坦福大学CS224n课程中关于深度学习在自然语言处理(NLP)领域的应用,重点探讨了Word2Vec和GloVe两种词嵌入模型的原理与实现方法。通过具体案例分析,深入阐述了这两种模型在提升NLP任务性能方面的优势与应用场景。 ... [详细]
  • 作为140字符的开创者,Twitter看似简单却异常复杂。其简洁之处在于仅用140个字符就能实现信息的高效传播,甚至在多次全球性事件中超越传统媒体的速度。然而,为了支持2亿用户的高效使用,其背后的技术架构和系统设计则极为复杂,涉及高并发处理、数据存储和实时传输等多个技术挑战。 ... [详细]
  • 本文深入探讨了 MXOTDLL.dll 在 C# 环境中的应用与优化策略。针对近期公司从某生物技术供应商采购的指纹识别设备,该设备提供的 DLL 文件是用 C 语言编写的。为了更好地集成到现有的 C# 系统中,我们对原生的 C 语言 DLL 进行了封装,并利用 C# 的互操作性功能实现了高效调用。此外,文章还详细分析了在实际应用中可能遇到的性能瓶颈,并提出了一系列优化措施,以确保系统的稳定性和高效运行。 ... [详细]
  • 智能制造数据综合分析与应用解决方案
    在智能制造领域,生产数据通过先进的采集设备收集,并利用时序数据库或关系型数据库进行高效存储。这些数据经过处理后,通过可视化数据大屏呈现,为生产车间、生产控制中心以及管理层提供实时、精准的信息支持,助力不同应用场景下的决策优化和效率提升。 ... [详细]
  • 深入解析PowerShell Stable与Preview版的区别与应用
    在研究 PowerShell 的 GitHub 发布页面时,我们注意到除了稳定的 6.2.0 版本外,还推出了 6.2.0 的预览版。本文将详细探讨这两个版本之间的主要区别及其应用场景,帮助用户更好地选择适合自身需求的版本。我们将分析稳定版的成熟度、性能和安全性,以及预览版的新功能、改进和潜在风险,为用户提供全面的参考信息。 ... [详细]
  • BZOJ1034 详细解析与算法优化
    本文深入解析了BZOJ1034问题,并提出了优化算法。通过借鉴广义田忌赛马的贪心策略,当己方当前最弱的马优于对方最弱的马时进行匹配;同样地,若己方当前最强的马优于对方最强的马,也进行匹配。此方法在保证胜率的同时,有效提升了算法效率。 ... [详细]
  • 负载均衡基础概念与技术解析
    随着互联网应用的不断扩展,用户流量激增,业务复杂度显著提升,单一服务器已难以应对日益增长的负载需求。负载均衡技术应运而生,通过将请求合理分配到多个服务器,有效提高系统的可用性和响应速度。本文将深入探讨负载均衡的基本概念和技术原理,分析其在现代互联网架构中的重要性及应用场景。 ... [详细]
  • 抖音AI特效风靡网络,真人瞬间变身动漫角色,吴亦凡、PDD和戚薇纷纷沉迷其中
    近期,抖音推出的一款名为“变身漫画”的AI特效在社交媒体上迅速走红,吸引了大量用户尝试。不仅普通网友积极参与,连吴亦凡、PDD和戚薇等明星也纷纷加入,体验将真人瞬间转化为动漫角色的神奇效果。这一特效凭借其高度的趣味性和创新性,迅速成为网络热议的话题。 ... [详细]
  • 题目描述:小K不幸被LL邪教洗脑,洗脑程度之深使他决定彻底脱离这个邪教。在最终离开前,他计划再进行一次亚瑟王游戏。作为最后一战,他希望这次游戏能够尽善尽美。众所周知,亚瑟王游戏的结果很大程度上取决于运气,但通过合理的策略和算法优化,可以提高获胜的概率。本文将详细解析洛谷P3239 [HNOI2015] 亚瑟王问题,并提供具体的算法实现方法,帮助读者更好地理解和应用相关技术。 ... [详细]
  • 利用PaddleSharp模块在C#中实现图像文字识别功能测试
    PaddleSharp 是 PaddleInferenceCAPI 的 C# 封装库,适用于 Windows (x64)、NVIDIA GPU 和 Linux (Ubuntu 20.04) 等平台。本文详细介绍了如何使用 PaddleSharp 在 C# 环境中实现图像文字识别功能,并进行了全面的功能测试,验证了其在多种硬件配置下的稳定性和准确性。 ... [详细]
  • 优化后的标题:PHP分布式高并发秒杀系统设计与实现
    PHPSeckill是一个基于PHP、Lua和Redis构建的高效分布式秒杀系统。该项目利用php_apcu扩展优化性能,实现了高并发环境下的秒杀功能。系统设计充分考虑了分布式架构的可扩展性和稳定性,适用于大规模用户同时访问的场景。项目代码已开源,可在Gitee平台上获取。 ... [详细]
  • 利用Redis HyperLogLog高效统计微博日活跃和月活跃用户数
    本文探讨了如何利用Redis的HyperLogLog数据结构高效地统计微博平台的日活跃用户(DAU)和月活跃用户(MAU)数量。通过HyperLogLog的高精度和低内存消耗特性,可以实现对大规模用户数据的实时统计与分析,为平台运营提供有力的数据支持。 ... [详细]
author-avatar
IP-COM东莞办事处_426
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有