热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch数据集类和数据加载类的一些尝试

最近在学习PyTorch, 但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自

最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实验,小尝试。

 

 

下面给出一个常用的数据类使用方式:

def data_tf(x):
    x = np.array(x, dtype='float32') / 255 # 将数据变到 0 ~ 1 之间
    x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到
    x = x.reshape((-1,)) # 拉平
    x = torch.from_numpy(x)
    return x



from torchvision.datasets import MNIST # 导入 pytorch 内置的 mnist 数据
train_set = MNIST('./data', train=True, transform=data_tf, download=True) # 载入数据集,申明定义的数据变换
test_set = MNIST('./data', train=False, transform=data_tf, download=True)

 

 

其中,  data_tf  并不是必须要有的,比如:

from torchvision.datasets import MNIST # 导入 pytorch 内置的 mnist 数据
train_set = MNIST('./data', train=True, download=True) # 载入数据集,申明定义的数据变换
test_set = MNIST('./data', train=False, download=True)

这里面的MNIST类是框架自带的,可以自动下载MNIST数据库,   ./data  是指将下载的数据集存放在当前目录下的哪个目录下,    train 这个属性 True时 则在 ./data文件夹下面在建立一个 train的文件夹然后把下载的数据存放在其中,  当train属性是False的时候则把下载的数据放在 test文件夹下面。   

划线部分是老版本的PyTorch的处理方式,  最近试了一下最新版本  PyTorch 1.0   ,   train为True的时候是把数据放在  ./data/processed  文件夹下面, 命名为training.pt  ,  为False 的时候则放在  ./data/processed  文件夹下面, 命名为test.pt  。

 

PyTorch   数据集类  和   数据加载类     的一些尝试

 

PyTorch   数据集类  和   数据加载类     的一些尝试

 

 

 

 

这时候就出现了一个问题, 如果你使用的数据集不是框架自带的那么如何使用数据类呢,这个时候就要使用  pytorch 中的  Dataset 类了。

from torch.utils.data import Dataset

我们需要重写   Dataset类, 需要实现的方法为  __len__   和   __getitem__    这两个内置方法,  这里可以看出其思想就是要重写的类需要支持按照索引查找的方法。

 

 

 

 

这里我们还是举个例子:

PyTorch   数据集类  和   数据加载类     的一些尝试

 

 

PyTorch   数据集类  和   数据加载类     的一些尝试

 

PyTorch   数据集类  和   数据加载类     的一些尝试

 

PyTorch   数据集类  和   数据加载类     的一些尝试

从这个例子可以看出  mydataset就是我们自定义的 myDataset 类生成的自定义数据类对象。我们可以在myDataset类中自定义一些方法来对需要的数据进行处理。

为说明该问题另附加一个例子:

from torch.utils.data import Dataset


#需要在pytorch中使用的数据
data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]]


class myDataset(Dataset):
    def __init__(self, indata):
        self.data=indata
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]


mydataset=myDataset(data)

 

 

那么又来了一个问题,我们不重写 Dataset类的话可不可以呢, 经过尝试发现还真可以,如下:

 PyTorch   数据集类  和   数据加载类     的一些尝试

 

 

 

又如:

PyTorch   数据集类  和   数据加载类     的一些尝试

 

PyTorch   数据集类  和   数据加载类     的一些尝试

 

PyTorch   数据集类  和   数据加载类     的一些尝试

 

PyTorch   数据集类  和   数据加载类     的一些尝试

由这个例子可以看出数据类对象可以不重写Dataset类, 只要具备  __len__      __getitem__    方法就可以。而且从这个例子我们可以看出  DataLoader  是一个迭代器, 如果shuffle 设置为 True 那么在每次迭代之前都会重新排序。

同时由上面两个例子可以看出  DataLoader类会把传入的数据集合中的数据转化为  torch.tensor 类型, 当然是采用默认的  DataLoader类中转化函数 transform的情况下。

这也就是说  DataLoader 默认的转化函数 transform操作为    传入的[ [x, x, x], [y, y, y] ] 输出的是 [ tensor([x, x, x]),  tensor([y, y, y]) ] ,

传入的是  tensor([ [x, x, x], [y, y, y] ]) 输出的是 tensor([ tensor([x, x, x]),  tensor([y, y, y]) ] ),   (这个例子是在   batch_size=2 的情况)。

 

 

 

综上,可知  其实   Dataset类, 和 DataLoader类其实在pytorch 计算过程中都不是一定要有的,  其中Dataset类是起一个规范作用,意义在于要人们对不同的类型数据做一些初步的调整,使其支持按照索引读取,以使其可以在 DataLoader中使用。

DataLoader 是一个迭代器, 可以方便的通过设置 batch_size 来实现 batch过程,transform则是对数据的一些处理。

 

 

 

 

---------------------------------------------------------------------------------------------------

 

上述内容更正:

 

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


#需要在pytorch中使用的数据
data=[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3], [5.1, 5.2, 5.3]]

class myDataset(Dataset):
    def __init__(self, indata):
        self.data=indata
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]


mydataset=myDataset(data)
train_data=DataLoader(mydataset, batch_size=3, shuffle=True)

print("上文的错误操作:")

for i in train_data:
    print(i)
    print('-'*30)
print('again')
for i in train_data:
    print(i)
    print('-'*30)


#########################################


data=np.array(data)
data=torch.from_numpy(data)


mydataset=myDataset(data)
train_data=DataLoader(mydataset, batch_size=3, shuffle=True)


print("修正后的正确操作:")

for i in train_data:
    print(i)
    print('-'*30)
print('again')
for i in train_data:
    print(i)
    print('-'*30)

 

 

 

(base) devil@devilmaycry:/tmp$ python w.py 
上文的错误操作:
[tensor([3.1000, 4.1000, 5.1000], dtype=torch.float64), tensor([3.2000, 4.2000, 5.2000], dtype=torch.float64), tensor([3.3000, 4.3000, 5.3000], dtype=torch.float64)]
------------------------------
[tensor([1.1000, 2.1000], dtype=torch.float64), tensor([1.2000, 2.2000], dtype=torch.float64), tensor([1.3000, 2.3000], dtype=torch.float64)]
------------------------------
again
[tensor([3.1000, 5.1000, 1.1000], dtype=torch.float64), tensor([3.2000, 5.2000, 1.2000], dtype=torch.float64), tensor([3.3000, 5.3000, 1.3000], dtype=torch.float64)]
------------------------------
[tensor([2.1000, 4.1000], dtype=torch.float64), tensor([2.2000, 4.2000], dtype=torch.float64), tensor([2.3000, 4.3000], dtype=torch.float64)]


------------------------------

修正后的正确操作: tensor([[
2.1000, 2.2000, 2.3000], [1.1000, 1.2000, 1.3000], [3.1000, 3.2000, 3.3000]], dtype=torch.float64) ------------------------------ tensor([[4.1000, 4.2000, 4.3000], [5.1000, 5.2000, 5.3000]], dtype=torch.float64) ------------------------------ again tensor([[5.1000, 5.2000, 5.3000], [4.1000, 4.2000, 4.3000], [3.1000, 3.2000, 3.3000]], dtype=torch.float64) ------------------------------ tensor([[2.1000, 2.2000, 2.3000], [1.1000, 1.2000, 1.3000]], dtype=torch.float64) ------------------------------

 

可以看出  传入到   Dataset  中的对象必须是  torch  类型的 tensor  类型, 如果传入的是list则会得出错误结果。

 

 

 

-----------------------------------------------------------------------------------------------------

 

 

补充:

之所以发现上面的这个错误,是因为发现了下面的代码:

import numpy as np
from torchvision.datasets import mnist # 导入 pytorch 内置的 mnist 数据
from torch.utils.data import DataLoader
#from torch.utils.data import Dataset


def data_tf(x):
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5 # 数据预处理,标准化
    x = x.reshape((-1,)) # 拉平
    x = torch.from_numpy(x)
    return x


#Dataset
# 重新载入数据集,申明定义的数据变换
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)


train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

 

从上面的   data_tf  函数中我们发现,  Dataset对象返回的是   torch 的  tensor 对象。

 


推荐阅读
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • 1Lock与ReadWriteLock1.1LockpublicinterfaceLock{voidlock();voidlockInterruptibl ... [详细]
  • 本文详细介绍了SQL日志收缩的方法,包括截断日志和删除不需要的旧日志记录。通过备份日志和使用DBCC SHRINKFILE命令可以实现日志的收缩。同时,还介绍了截断日志的原理和注意事项,包括不能截断事务日志的活动部分和MinLSN的确定方法。通过本文的方法,可以有效减小逻辑日志的大小,提高数据库的性能。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 本文讨论了在Windows 8上安装gvim中插件时出现的错误加载问题。作者将EasyMotion插件放在了正确的位置,但加载时却出现了错误。作者提供了下载链接和之前放置插件的位置,并列出了出现的错误信息。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 使用Ubuntu中的Python获取浏览器历史记录原文: ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • MATLAB函数重名问题解决方法及数据导入导出操作详解
    本文介绍了解决MATLAB函数重名的方法,并详细讲解了数据导入和导出的操作。包括使用菜单导入数据、在工作区直接新建变量、粘贴数据到.m文件或.txt文件并用load命令调用、使用save命令导出数据等方法。同时还介绍了使用dlmread函数调用数据的方法。通过本文的内容,读者可以更好地处理MATLAB中的函数重名问题,并掌握数据导入导出的各种操作。 ... [详细]
  • 本文介绍了一种轻巧方便的工具——集算器,通过使用集算器可以将文本日志变成结构化数据,然后可以使用SQL式查询。集算器利用集算语言的优点,将日志内容结构化为数据表结构,SPL支持直接对结构化的文件进行SQL查询,不再需要安装配置第三方数据库软件。本文还详细介绍了具体的实施过程。 ... [详细]
author-avatar
拍友2502906483
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有