数据集的基本结构
可以参考官方文档 web documentation。主要有三个类:Dataset、Sampler 和 DataLoader。
-
Dataset:
代表数据集的抽象类;所有其他数据集都应该继承它。所有的子类都应该覆盖 `__len__`(提供数据集的大小)和 `__getitem__`(支持范围从 0 到 len(self) 的整型索引)。
-
Sampler:
所有采样器的基准类;每个采样器子类必须提供 `__iter__` 方法,提供一种遍历数据集元素的索引的方法,以及一个返回迭代器长度的 `__len__` 方法。
-
DataLoader:
组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
简单的数据集类:
train_images_path = "./data/train_images"
train_labels_path = "./data/train_labels"


class RSDataset(Dataset):
    """Image/mask segmentation dataset.

    Pairs each image file in ``input_root`` with the same-named mask file
    under ``label_root``.

    Parameters
    ----------
    input_root : str
        Directory containing the input images.
    label_root : str, optional
        Directory containing the label masks.  Defaults to ``input_root``
        with ``"train_images"`` replaced by ``"train_labels"``, matching
        the original hard-coded lookup, so existing callers keep working.
    mode : str
        ``"train"`` applies the random augmentation pipeline before
        returning the pair; any other value skips augmentation.
    debug : bool
        When True, keep only the first 500 (sorted) images for fast runs.
    """

    def __init__(self, input_root, label_root=None, mode="train", debug=False):
        super().__init__()
        self.input_root = input_root
        # Bug fix: the original signature had no label_root, so the caller's
        # positional label path collided with the ``mode`` keyword argument.
        # Default to the original path-substitution scheme when not given.
        if label_root is None:
            label_root = input_root.replace("train_images", "train_labels")
        self.label_root = label_root
        self.mode = mode
        # Sort once; the original duplicated the sorted() call in each branch.
        self.input_ids = sorted(os.listdir(self.input_root))
        if debug:
            self.input_ids = self.input_ids[:500]
        self.mask_transform = transforms.Compose([
            transforms.Lambda(to_monochrome),
            transforms.Lambda(to_tensor),
        ])
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Geometric augmentations applied jointly to image and mask.
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
        ])

    def __len__(self):
        # Number of image/mask pairs.
        return len(self.input_ids)

    def __getitem__(self, idx):
        image_name = os.path.join(self.input_root, self.input_ids[idx])
        image = np.array(cv2.imread(image_name), dtype=np.float32)
        # Masks are stored as 0/255 images; scale to 0/1.
        mask_name = os.path.join(self.label_root, self.input_ids[idx])
        mask = np.array(cv2.imread(mask_name)) / 255
        if self.mode == "train":
            image, mask = self.transform(image, mask)
        # Channels are redundant for a binary mask; keep the first one.
        # (Removed the original's dead ``np.zeros`` pre-allocation and the
        # duplicated return in both branches.)
        mask1 = mask[:, :, 0]
        return self.image_transform(image), self.mask_transform(mask1)
def build_loader(input_img_folder="./data/train_images",
                 batch_size=16,
                 num_workers=4):
    """Build train/validation DataLoaders with a reproducible 85/15 split.

    Parameters
    ----------
    input_img_folder : str
        Directory whose files define the dataset; also passed to RSDataset.
    batch_size : int
        Batch size for both loaders.
    num_workers : int
        Worker processes per loader.

    Returns
    -------
    (DataLoader, DataLoader)
        ``(train_loader, valid_loader)``.
    """
    # len() does not need the list sorted; the original sorted() was wasted work.
    num_train = len(os.listdir(input_img_folder))
    indices = list(range(num_train))
    seed(128381)  # fixed seed -> the train/valid split is reproducible
    # sample(population, len(population)) is a shuffled copy of the indices.
    indices = sample(indices, len(indices))
    split = int(np.floor(0.15 * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    # Bug fix: the original passed the label folder as a second positional
    # argument, which bound to RSDataset's ``mode`` parameter and then
    # collided with the explicit ``mode=`` keyword (TypeError).  The label
    # path is derived inside RSDataset, so only the image folder is passed.
    # Also use the ``input_img_folder`` parameter instead of re-hard-coding it.
    train_dataset = RSDataset(input_img_folder, mode="train")
    val_dataset = RSDataset(input_img_folder, mode="valid")

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=True)
    return train_loader, valid_loader