热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch0.40迁移指南(非正式官翻文档)

写在前面:本次更新最大亮点就是支持Windows啦,这对于初学者来说是件大喜事,不用再去折腾安装学习Linux系统就能正儿八经地搞深度学习了,新特性网上到处都是,我就不赘述了。0.

写在前面:本次更新最大亮点就是支持Windows啦,这对于初学者来说是件大喜事,不用再去折腾安装学习Linux系统就能正儿八经地搞深度学习了,新特性网上到处都是,我就不赘述了。0.4版本在函数接口上与之前版本还是有一些不同的,私以为最主要的还是合并了Tensor与Variable,还有就是对数据模型迁移方式的更改。下面我根据个人理解翻译了官方给的Migration Guide,愿为pytorch推广贡献一点自己的力量,理解不对的地方烦请各位指出批评,谢谢

PyTorch 0.40 迁移指南

欢迎阅读本指南。在这个版本中pytorch推出了很多新特性并修复了原来的BUG,给用户提供了更为便捷的函数接口。在本指南中,我们只挑重点来讲,告诉大家如何将原来的代码迁移到新的版本,以下是主要的更新特性:

  • TensorVariable 合并
  • 支持0维向量(标量)Tensor
  • 弃坑Volatile方式
  • 全新Tensor定义方式
  • 指定计算设备的函数更方便了

合并Tensor和Variable

torch.tensortorch.autograd.Variable现在合为一类了,更准确地讲,torch.tensor具备原来Variable的全部功能。现在Variable还能用,但返回的也是torch.Tensor类型。这意味着以后没必要使用Variable包裹Tensor数据了。

Tensor.type()变更

使用type()不再返回数据类型(float,double…)了。使用isinstance()x.type()可以查看其具体数据类型。

>>>x=torch.DoubleTensor([1,1,1])
>>>print(type())#返回所属类
""
>>>print(x.type())
"torch.DoubleTensor"
>>>print(isinstance(x,torch.DoubleTensor))
True

autograd现在是怎样工作的?

requires_grad曾是autograd的关键选项,现在被迁移到Tensor的属性,用法和之前的一样。当设置requires_grad=True时,autograd开始自动记录差分值。例如:

>>>x=torch.ones(1)
>>>x.requires_grad
False
>>>y=torch.ones(1)
>>>z=x+y
>>>z.requires_grad
>>>z.backward()
RuntimeError:element 0 of tensors does not require grad and does not have a grad_fn
>>>w=torch.ones(1,requires_grap=True)
>>>w.requires_grad
True
>>>total=w+z
>>>total.requires_grad
True
>>>total.backward()
>>>w.grad
tensor([1.])
>>>z.grad==x.grad==y.grad==None
True

设置requires_grad
除了在初始化的时候设置外,还可以使用my_tensor.requires_grad()来设置

.data用法

在之前的版本,使用.data将Variable转化为Tensor。现在合并之后,调用y=x.data 后,y是另一个Tensor,只是与x共用数据部分,但计算的求导信息不会记载到x 中。 但是,在某些情况使用.data欠妥。任何x.data的数据变化都不会影响到x的梯度。更为保险的方法是使用x.detach(),它返回的也是与原变量共享数据的Tensor 也不会影响计算的梯度信息,但是它会有梯度变化的报告信息

支持0维Tensor(标量)

之前版本中,求一维Tensor的索引返回一个数值,但是一维Variable却返回(1,)!相似的情况同样出现在求和函数中,例如tensor.sum()返回一个数值,然而Variable.sum()返回的是(1,) 还好,本次更新后pytorch支持标量了!标量可以直接用torch.tensor创建就像numpy.array那样

>>>torch.tensor(3.1416)
tensor(3.1416)
>>>torch.tensor(3.1416).size()
torch.Size([])#表明这是0维数据,即标量
>>>torch.tensor([3]).size()
torch.Size([1])
>>>vector=torch.arange(2,6)
>>>vector
tensor([2.,3.,4.,5.])
>>>vector.size()
torch.Size([4])
>>>vector[3].item()
5.0
>>>mysum=torch.tensor([2,3]).sum()
>>>mysum
tensor(5)
>>>mysum.size()
torch.Size([])

个人理解:新版本支持标量了,可以直接用,不像原来单个数据还给搞出个一维数组

损失积累
之前都使用total_loss+=loss.data[0]累积损失率。在0.4版本中有0维的标量了,直接用loss.item()得到其loss的数值就可以了。

反对使用volatile选项

volatile选项在0.4版本中不推荐使用了,之前版本中给变量设置volatile=True一遍其不求导计算。现在这个功能被其他函数替代 torch.no_grad(),torch.set_grad_enabled()

>>>x=torch.zeros(1,requires_grad=True)
>>>with torch.no_grad():
y=x*2
>>>y.requires_grad
False
>>>is_train=False
>>>with torch.set_grad_enabled(is_train):
y=x*2
>>>y.requires_grad
True
>>>torch.set_grad_enabled(False)
>>>y=x*2
>>>y.requires_grad
False

dtypes,devices变更

在0.40版本中,使用torch.dtype,torch.devicetorch.layout类来分配管理数据设备类型

torch.dtype

以下是可用的数据类型表和它相应的tensor类型

《PyTorch 0.40 迁移指南(非正式官翻文档)》
《PyTorch 0.40 迁移指南(非正式官翻文档)》

torch.device

torch.device包含两种设备类型,cpu和cuda。对于GPU还可以选择设备编号,例如torch.device(‘{设备类型}:{设备编号}’),如果不确定设备编号,默认使用torch.device('cuda')就会默认调用当前的显卡。可以使用torch.cuda.current_device()查看当前显卡

torch.layout

torch.layout代表tensor数据配置

创建Tensor

在新版本中创建Tensor需要考虑dtype,device,layout和requires_grad,例如

>>>device=torch.device('cuda:1')
>>>x-torch.randn(3,3,dtype=torch.float64,device=device)
tensor([-0.6344,0.8534,-1.2354],
[0.8414,1.7962,1.0589],
[-0.1369,-1.0462,-0.4373],dtype=torch.float64,device='cuda:1')
>>>x.requires_grad
False
>>>x=torch.zeros(3,requires_grad=True)
>>>x.requires_grad
True

torch.tensor(data,…)

torch.tensor()就像numpy.array()构造器,可以将数组类数据直接转换为Tensor,本版本中这个函数也可以构造标量。如果初始化没有指定dtype数据类型,pytorch将自动分配合适类型,我们极力推荐使用这种方法将已有的数据类(例如list)转化为Tensor

>>>cuda=torch.device('cuda')
>>>torch.tensor([[1],[2],[3]],dtype=torch.half,device=cuda)
tensor([[1],
[2],
[3]],device='cuda:0')
>>>torch.tensor(1)
tensor(1)
>>>torch.tensor([1,2,3]).dtype
torch.float32
>>>torch.tensor([1,2]).dtype
torch.int64

下面介绍其他创建Tensor的方法:

  • torch.*_like接受Tensor数据(注意不是数据的尺寸),如果不设置相关参数,它默认返回一个具有相同属性的Tensor

>>>x=torch.randn(3,dtype=torch.float64)
>>>torch.zeros_like(x)
tensor([0.,0.,0.],dtype=torch.float64)
>>>torch.zeros_like(x,dtype=torch.int)
tensor([0,0,0],dtype=torch.int32)

  • tensor.new_*使用尺寸作为参数创建具有相同属性的Tensor

>>>x=torch.randn(3,dtype=torch.float64)
>>>x.new_ones(2)
tensor([1.,1.],dtype=torch.float64)
>>>x.new_ones(4,dtype=torch.int)
tensor([1,1,1,1],dtype=torch.int32)

如果需要创建指定尺寸的Tensor,可以直接用元组指定尺寸作为参数,例如torch.zeros((2,3))torch.zeros(2,3),这样就能创建尺寸为2×3,元素为0的Tensor啦。

《PyTorch 0.40 迁移指南(非正式官翻文档)》
《PyTorch 0.40 迁移指南(非正式官翻文档)》

注意: torch.from_numpy()只能接受Numpy的ndarray作为参数输入

模型数据迁移

在之前的版本中,当不确定计算设备(cpu or which GPU?)情况时不太好写代码。 0.4版本做出了如下更新

  • 使用to方法可以轻松转换训练的网络(module)和数据到不同计算设备运行
  • device属性用来指定使用的计算设备,之前要用cpu(),cuda()转换模型或数据 示例demo:

device=torch.device('cuda:0' if torch.cuda.is_available else 'cpu')
input=data.to(device)#直接指定数据到哪个设备中
model=MyModule().to(device)#同样,网络模型转换到指定设备中

例程demo

对比了0.310.4的代码

  • 0.31(old)

model=MyRNN()
if use_cuda:
model=model.cuda()
total_loss=0
for input,target in train_loader:
input,target=Variable(input),Variable(target)
hidden=Variable(torch.zeros(*h_shape))#隐藏层
if use_cuda:
input,target,hidden=input.cuda(),target.cuda(),hidden.cuda()
total_loss+=loss.data[0]
for input,target in test_loader:
input=Variable(input,volatile=True)
if use_cuda:
...
...

  • 0.40(new)

device=torch.device('cuda' if use_cuda else 'cpu')
model=MyRNN().to(device)
total_loss=0
for input,target in train_loader:
input,target=input.to(device),target.to(device)
hidden=input.new_zeros(*h_shape)
...
total_loss+=loss.item()#得到1维张量
with torch.no_grad():#不计算梯度
for input,target in test_loader:
...

最后附上英文原版连接:PyTorch 0.4.0 Migration Guide


推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了iOS数据库Sqlite的SQL语句分类和常见约束关键字。SQL语句分为DDL、DML和DQL三种类型,其中DDL语句用于定义、删除和修改数据表,关键字包括create、drop和alter。常见约束关键字包括if not exists、if exists、primary key、autoincrement、not null和default。此外,还介绍了常见的数据库数据类型,包括integer、text和real。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 上图是InnoDB存储引擎的结构。1、缓冲池InnoDB存储引擎是基于磁盘存储的,并将其中的记录按照页的方式进行管理。因此可以看作是基于磁盘的数据库系统。在数据库系统中,由于CPU速度 ... [详细]
  • Linux服务器密码过期策略、登录次数限制、私钥登录等配置方法
    本文介绍了在Linux服务器上进行密码过期策略、登录次数限制、私钥登录等配置的方法。通过修改配置文件中的参数,可以设置密码的有效期、最小间隔时间、最小长度,并在密码过期前进行提示。同时还介绍了如何进行公钥登录和修改默认账户用户名的操作。详细步骤和注意事项可参考本文内容。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • IB 物理真题解析:比潜热、理想气体的应用
    本文是对2017年IB物理试卷paper 2中一道涉及比潜热、理想气体和功率的大题进行解析。题目涉及液氧蒸发成氧气的过程,讲解了液氧和氧气分子的结构以及蒸发后分子之间的作用力变化。同时,文章也给出了解题技巧,建议根据得分点的数量来合理分配答题时间。最后,文章提供了答案解析,标注了每个得分点的位置。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 本文讨论了编写可保护的代码的重要性,包括提高代码的可读性、可调试性和直观性。同时介绍了优化代码的方法,如代码格式化、解释函数和提炼函数等。还提到了一些常见的坏代码味道,如不规范的命名、重复代码、过长的函数和参数列表等。最后,介绍了如何处理数据泥团和进行函数重构,以提高代码质量和可维护性。 ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
author-avatar
一个幼儿女教师上
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有