热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

np读取csv文件_[记录]Pytorch利用图片数据和csv标签文件建立Dataset

获取了后续需要训练的图片image和含有图片对应的标签csv文件之后,需要将这两者结合,建立后续训练的Dataset。importtorchimpor

获取了后续需要训练的图片image和含有图片对应的标签csv文件之后,需要将这两者结合,建立后续训练的Dataset。

ab105aa73bd61ef5e0c8a59cd732d68d.png

import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image#建立自己的dataset
class CreateDatasetFromImages(Dataset):def __init__(self, csv_path, file_path, resize_height=256, resize_width=256):"""Args:csv_path (string): csv 文件路径img_path (string): 图像文件所在路径transform: transform 操作"""# 需要调整后的照片尺寸,我这里每张图片的大小尺寸不一致#self.resize_height = resize_heightself.resize_width = resize_width# csv_path = "C:UsersandroidcatDesktopcancer_classificationWarwick QU Dataset (Released 2016_07_08)Grade_train.csv"self.file_path = file_pathself.to_tensor = transforms.ToTensor() #将数据转换成tensor形式# 读取 csv 文件#利用pandas读取csv文件self.data_info = pd.read_csv(csv_path, header=None) #header=None是去掉表头部分# 文件第一列包含图像文件的名称self.image_arr = np.asarray(self.data_info.iloc[1:, 0]) #self.data_info.iloc[1:,0表示读取第一列,从第二行开始一直读取到最后一行# 第四列是图像的 labelself.label_arr = np.asarray(self.data_info.iloc[1:, 4])# 计算 lengthself.data_len = len(self.data_info.index) - 1def __getitem__(self, index):# 从 image_arr中得到索引对应的文件名single_image_name = self.image_arr[index]# 读取图像文件img_as_img = Image.open(self.file_path + single_image_name + ".bmp")#如果需要将RGB三通道的图片转换成灰度图片可参考下面两行# if img_as_img.mode != 'L':# img_as_img = img_as_img.convert('L')#设置好需要转换的变量,还可以包括一系列的nomarlize等等操作transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])img_as_img = transform(img_as_img)# 得到图像的 labellabel = self.label_arr[index]return (img_as_img, label) #返回每一个index对应的图片数据和对应的labeldef __len__(self):return self.data_len

然后使用Jupyter Notebook加载训练集图片之后,可以得到:

MyTrainDataset = CreateDatasetFromImages("C:/Users/androidcat/Desktop/cancer_classification/Warwick QU Dataset (Released 2016_07_08)/Grade_train.csv","C:/Users/androidcat/Desktop/cancer_classification/Warwick QU Dataset (Released 2016_07_08)/train/")
train_loader = torch.utils.data.DataLoader(dataset=MyTrainDataset,batch_size=1, shuffle=False,)
MyTrainDataset.data_info

4f3c934c992f6040301d48ab2456e96f.png



推荐阅读
author-avatar
qaqa
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有