热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pythonDataSet+Dataloader深度学习编程细节_数据集pytorchDataset的构建与使用

深度学习中许多网络的设计都需数据集的预处理功能辅助,本文对DataSetDataloader的使用做介绍。DataSet构建(简单示例)构建数据集需要继承torch
  • 深度学习中许多网络的设计都需数据集的预处理功能辅助,本文对DataSet + Dataloader 的使用做介绍。

DataSet构建(简单示例)

        构建数据集需要继承torch.utils.data.dataset的Dataset类重写init,getitem(self, mask),len三个方法。然后使用torch.utils.data import DataLoader来加载你创建的数据集Dataset。

import argparse
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as modelsimport numpy as np
import os, imageiofrom torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):def __init__(self, data, label):#传入参数是我们的数据集(data)和标签集(label)self.data = dataself.label = labelself.length = data.shape[0]def __getitem__(self, mask):# 获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。label = self.label[mask]data = self.data[mask]return label, datadef __len__(self):# print(self.length)return self.lengthtrain_set = MyDataSet(xb,yb)# xb,yb为所有的数据
# train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100 # number of epochs to train on
batch_size = 1024 # training batch size
train_data = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self,depth=4,mapping_size=2,hidden_size=256):super().__init__()layers = []layers.append(nn.Linear(mapping_size,hidden_size))layers.append(nn.ReLU(inplace=True))for _ in range(depth-2):layers.append(nn.Linear(hidden_size,hidden_size))layers.append(nn.ReLU(inplace=True))layers.append(nn.Linear(hidden_size,3))self.layers = nn.Sequential(*layers)def forward(self,x):return torch.sigmoid(self.layers(x))
model = MLP()
for epoch in range(num_epoch ):model.train()for batchsz, (label, data) in enumerate(train_data):# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labelsprint("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))

图像的分割处理数据集的构建

添加链接描述
添加链接描述

构建自监督任务的数据集(用一个数据集构建正负样本)

from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):def __init__(self, data, label):#传入参数是我们的数据集(data)和标签集(label)self.data = dataself.label = labelself.length = data.shape[0]def __getitem__(self, mask):# 获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。label = self.label[mask]data = self.data[mask]return label, datadef __len__(self):# print(self.length)return self.length

C&G

  • Pytorch如何合并多个dataloader

后续(+捕获异常)

image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
IndexError: list index out of range

先加个捕获异常:

def __getitem__(self, index):video_name = self.samples[index].split('/')[-2]frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])batch = []for i in range(self._time_step+self._num_pred):try:image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)except :print('error from --- model utils')print(frame_name)print(i)if self.transform is not None:batch.append(self.transform(image))return np.concatenate(batch, axis=0)


推荐阅读
  • 本人学习笔记,知识点均摘自于网络,用于学习和交流(如未注明出处,请提醒,将及时更正,谢谢)OS:我学习是为了上 ... [详细]
  • 本文介绍了C#中生成随机数的三种方法,并分析了其中存在的问题。首先介绍了使用Random类生成随机数的默认方法,但在高并发情况下可能会出现重复的情况。接着通过循环生成了一系列随机数,进一步突显了这个问题。文章指出,随机数生成在任何编程语言中都是必备的功能,但Random类生成的随机数并不可靠。最后,提出了需要寻找其他可靠的随机数生成方法的建议。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • Python爬虫中使用正则表达式的方法和注意事项
    本文介绍了在Python爬虫中使用正则表达式的方法和注意事项。首先解释了爬虫的四个主要步骤,并强调了正则表达式在数据处理中的重要性。然后详细介绍了正则表达式的概念和用法,包括检索、替换和过滤文本的功能。同时提到了re模块是Python内置的用于处理正则表达式的模块,并给出了使用正则表达式时需要注意的特殊字符转义和原始字符串的用法。通过本文的学习,读者可以掌握在Python爬虫中使用正则表达式的技巧和方法。 ... [详细]
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
  • 语义分割系列3SegNet(pytorch实现)
    SegNet手稿最早是在2015年12月投出,和FCN属于同时期作品。稍晚于FCN,既然属于后来者,又是与FCN同属于语义分割网络 ... [详细]
  • 上一章讲了如何制作数据集,接下来我们使用mmcls来实现多标签分类。 ... [详细]
  • 湍流|低频_youcans 的 OpenCV 例程 200 篇106. 退化图像的逆滤波
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了youcans的OpenCV例程200篇106.退化图像的逆滤波相关的知识,希望对你有一定的参考价值。 ... [详细]
  • pytorch Dropout过拟合的操作
    这篇文章主要介绍了pytorchDropout过拟合的操作,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完 ... [详细]
  • 代码如下:#coding:utf-8importstring,os,sysimportnumpyasnpimportmatplotlib.py ... [详细]
  • 本文介绍了使用PHP实现断点续传乱序合并文件的方法和源码。由于网络原因,文件需要分割成多个部分发送,因此无法按顺序接收。文章中提供了merge2.php的源码,通过使用shuffle函数打乱文件读取顺序,实现了乱序合并文件的功能。同时,还介绍了filesize、glob、unlink、fopen等相关函数的使用。阅读本文可以了解如何使用PHP实现断点续传乱序合并文件的具体步骤。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 都会|可能会_###haohaohao###图神经网络之神器——PyTorch Geometric 上手 & 实战
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了###haohaohao###图神经网络之神器——PyTorchGeometric上手&实战相关的知识,希望对你有一定的参考价值。 ... [详细]
author-avatar
青春快乐1
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有