- 深度学习中许多网络的设计都需要数据集预处理功能的辅助,本文介绍 Dataset 与 DataLoader 的使用。
DataSet构建(简单示例)
构建数据集需要继承 torch.utils.data.dataset 中的 Dataset 类,并重写 `__init__`、`__getitem__(self, mask)`、`__len__` 三个方法;然后使用 torch.utils.data.DataLoader 来加载你创建的数据集 Dataset。
import argparse
import os
import random
import shutil
import time
import warnings

import imageio
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    """Map-style dataset that pairs each sample in ``data`` with its label.

    Args:
        data: Indexable container of samples. It must expose ``.shape``;
            the size of its first axis is used as the dataset length.
        label: Indexable container of labels, indexed with the same keys
            as ``data``.
    """

    def __init__(self, data, label):
        self.data = data
        self.label = label
        # The dataset length is the size of the leading axis of ``data``.
        self.length = data.shape[0]

    def __getitem__(self, mask):
        """Return the ``(label, data)`` pair selected by ``mask``.

        ``mask`` is forwarded unchanged to the indexing operator of both
        containers. Note the order of the result: label first, data second.
        """
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.length


# NOTE(review): ``xb`` and ``yb`` are not defined in the visible code; they
# must be created (samples and labels respectively) before this line runs.
train_set = MyDataSet(xb, yb)
# Number of passes the training loop makes over the data loader.
num_epoch = 100
# Number of samples the DataLoader groups into a single batch.
batch_size = 1024
# shuffle=True asks the loader to reshuffle the samples at every epoch.
train_data = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)


class MLP(nn.Module):
    """Fully connected network with ReLU hidden layers and a sigmoid output.

    Args:
        depth: Controls the number of ``nn.Linear`` layers. The network has
            one input layer, ``depth - 2`` hidden-to-hidden layers and one
            output layer, so values below 2 still produce two linear layers.
        mapping_size: Number of input features.
        hidden_size: Width of every hidden layer.

    The output layer always has 3 units.
    """

    def __init__(self, depth=4, mapping_size=2, hidden_size=256):
        super().__init__()
        layers = []
        # Input projection.
        layers.append(nn.Linear(mapping_size, hidden_size))
        layers.append(nn.ReLU(inplace=True))
        # Hidden stack: the input and output layers account for 2 of ``depth``.
        for _ in range(depth - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU(inplace=True))
        # Output projection to 3 values.
        layers.append(nn.Linear(hidden_size, 3))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and squash the result with a sigmoid."""
        return torch.sigmoid(self.layers(x))
model = MLP()
# Demonstration loop: it only iterates the loader and reports batch shapes;
# no forward pass, loss or optimizer step is performed here.
for epoch in range(num_epoch):
    model.train()
    # Each item is the (label, data) pair in the order the dataset returns it;
    # ``batch_idx`` is the position of the batch within the epoch.
    for batch_idx, (label, data) in enumerate(train_data):
        print("第 {} 个Batch size of label {} and size of data{}".format(batch_idx, label.shape, data.shape))
图像的分割处理数据集的构建
添加链接描述
添加链接描述
构建自监督任务的数据集(用一个数据集构建正负样本)
from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    """Map-style dataset that pairs each sample in ``data`` with its label.

    Args:
        data: Indexable container of samples. It must expose ``.shape``;
            the size of its first axis is used as the dataset length.
        label: Indexable container of labels, indexed with the same keys
            as ``data``.
    """

    def __init__(self, data, label):
        self.data = data
        self.label = label
        # The dataset length is the size of the leading axis of ``data``.
        self.length = data.shape[0]

    def __getitem__(self, mask):
        """Return the ``(label, data)`` pair selected by ``mask``.

        ``mask`` is forwarded unchanged to the indexing operator of both
        containers. Note the order of the result: label first, data second.
        """
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.length
C&G
后续(+捕获异常)
image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
IndexError: list index out of range
先加个捕获异常:
def __getitem__(self, index):
    """Load a run of consecutive frames starting at sample ``index``.

    The sample path ``self.samples[index]`` is split on ``/``: the
    second-to-last component is the video name, and the last component's
    second-to-last dot-separated piece is parsed as the integer start frame.
    ``self._time_step + self._num_pred`` frames are then loaded, transformed
    and concatenated along axis 0.

    If loading a frame fails, diagnostics are printed and the most recently
    loaded frame is reused in its place. If no frame has been loaded yet,
    the original exception is re-raised.

    NOTE(review): when ``self.transform`` is None nothing is appended, so
    ``np.concatenate`` receives an empty list and raises ``ValueError``.
    """
    sample_path = self.samples[index]
    video_name = sample_path.split('/')[-2]
    frame_name = int(sample_path.split('/')[-1].split('.')[-2])
    batch = []
    image = None
    have_frame = False  # True once at least one frame loaded successfully.
    for i in range(self._time_step + self._num_pred):
        try:
            image = np_load_frame(self.videos[video_name]['frame'][frame_name + i], self._resize_height, self._resize_width)
            have_frame = True
        except Exception:
            # Catch Exception rather than everything so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            print('error from --- model utils')
            print(frame_name)
            print(i)
            if not have_frame:
                # No earlier frame to fall back on: surface the real error
                # instead of failing later on an unbound ``image``.
                raise
            # Otherwise fall through and reuse the previous frame.
        if self.transform is not None:
            batch.append(self.transform(image))
    return np.concatenate(batch, axis=0)