我是Pytorch的新手。在开始使用CNN进行训练之前,我一直在尝试学习如何查看输入的图像。我很难将图像更改为可与matplotlib一起使用的形式。
到目前为止,我已经尝试过了:
from multiprocessing import freeze_support import torch from torch import nn import torchvision from torch.autograd import Variable from torch.utils.data import DataLoader, Sampler from torchvision import datasets from torchvision.transforms import transforms from torch.optim import Adam import matplotlib.pyplot as plt import numpy as np import PIL num_classes = 5 batch_size = 100 num_of_workers = 5 DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train' DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test' trans = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.Resize(32), transforms.CenterCrop(32), transforms.ToPImage(), transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)) ]) train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans) train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers) def imshow(img): img = img / 2 + 0.5 # unnormalize npimg = img.numpy() print(npimg) plt.imshow(np.transpose(npimg, (1, 2, 0, 1))) def main(): # get some random training images dataiter = iter(train_loader) images, labels = dataiter.next() # show images imshow(images) # print labels print(' '.join('%5s' % classes[labels[j]] for j in range(4))) if __name__ == "__main__": main()
但是,这引发了错误:
[[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353 0.27450982] [0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646 0.40784314] [0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354 0.34509805] ... [0.2784314 0.21960783 0.2352941 ... 0.5803922 0.46666667 0.25882354] [0.26666668 0.16862744 0.23137254 ... 0.2901961 0.29803923 0.2509804 ] [0.30980393 0.39607844 0.28627452 ... 0.1490196 0.10588235 0.19607842]] [[0.2352941 0.06274509 0.15686274 ... 0.09411764 0.3019608 0.19215685] [0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235 0.3019608 ] [0.20392156 0.13333333 0.1607843 ... 0.16862744 0.2117647 0.22745097] ... [0.18039215 0.16862744 0.1490196 ... 0.45882353 0.36078432 0.16470587] [0.1607843 0.10588235 0.14117646 ... 0.2117647 0.18039215 0.10980392] [0.18039215 0.3019608 0.2117647 ... 0.11372548 0.06274509 0.04705882]]] ... [[[0.8980392 0.8784314 0.8509804 ... 0.627451 0.627451 0.627451 ] [0.8509804 0.8235294 0.7921569 ... 0.54901963 0.5568628 0.56078434] [0.7921569 0.7529412 0.7176471 ... 0.47058824 0.48235294 0.49411765] ... [0.3764706 0.38431373 0.3764706 ... 0.4509804 0.43137255 0.39607844] [0.38431373 0.39607844 0.3882353 ... 0.4509804 0.43137255 0.39607844] [0.3882353 0.4 0.39607844 ... 0.44313726 0.42352942 0.39215687]] [[0.9254902 0.90588236 0.88235295 ... 0.60784316 0.6 0.5921569 ] [0.88235295 0.85490197 0.8235294 ... 0.5411765 0.5372549 0.53333336] [0.8235294 0.7882353 0.75686276 ... 0.47058824 0.47058824 0.47058824] ... [0.50980395 0.5176471 0.5137255 ... 0.58431375 0.5647059 0.53333336] [0.5137255 0.53333336 0.5254902 ... 0.58431375 0.5686275 0.53333336] [0.5176471 0.53333336 0.5294118 ... 0.5764706 0.56078434 0.5294118 ]] [[0.95686275 0.9372549 0.90588236 ... 0.18823528 0.19999999 0.20784312] [0.9098039 0.8784314 0.8352941 ... 0.1607843 0.17254901 0.18039215] [0.84313726 0.7921569 0.7490196 ... 0.1372549 0.14509803 0.15294117] ... [0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764 0.02745098] [0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078 0.03529412] [0.05098039 0.0745098 0.07843137 ... 0.12549019 0.10196078 0.04705882]]] [[[0.30588236 0.28627452 0.24313724 ... 0.2901961 0.26666668 0.21568626] [0.8156863 0.6666667 0.5921569 ... 0.18039215 0.23921567 0.21568626] [0.9019608 0.83137256 0.85490197 ... 0.21960783 0.36862746 0.23921567] ... [0.7058824 0.83137256 0.85490197 ... 0.2627451 0.24313724 0.20784312] [0.7137255 0.84313726 0.84705883 ... 0.26666668 0.29803923 0.21568626] [0.7254902 0.8235294 0.8392157 ... 0.2509804 0.27058825 0.2352941 ]] [[0.24705881 0.22745097 0.19215685 ... 0.2784314 0.25490198 0.19607842] [0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097 0.20392156] [0.5921569 0.4509804 0.49803922 ... 0.20784312 0.3764706 0.2352941 ] ... [0.42352942 0.4627451 0.42352942 ... 0.23921567 0.23137254 0.19999999] [0.45882353 0.5176471 0.35686275 ... 0.23921567 0.26666668 0.19607842] [0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567 0.21568626]] [[0.23137254 0.20784312 0.1490196 ... 0.30588236 0.28627452 0.19607842] [0.61960787 0.3764706 0.26666668 ... 0.16470587 0.24313724 0.21568626] [0.57254905 0.43137255 0.48235294 ... 0.2235294 0.40392157 0.25882354] ... [0.4 0.42352942 0.37254903 ... 0.25490198 0.24705881 0.21568626] [0.43137255 0.4509804 0.29411766 ... 0.25882354 0.28235295 0.20392156] [0.38431373 0.3529412 0.25490198 ... 0.2352941 0.25490198 0.23137254]]] [[[0.06274509 0.09019607 0.11372548 ... 0.5803922 0.5176471 0.59607846] [0.09411764 0.14509803 0.1372549 ... 0.5294118 0.49803922 0.5058824 ] [0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942 0.38431373] ... [0.15294117 0.12941176 0.1607843 ... 0.85882354 0.8509804 0.80784315] [0.14509803 0.10588235 0.1607843 ... 0.8666667 0.85882354 0.8 ] [0.1490196 0.10588235 0.16470587 ... 0.827451 0.8156863 0.7921569 ]] [[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412 0.6039216 ] [0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902 0.5372549 ] [0.03921568 0.0745098 0.09803921 ... 0.48235294 0.4392157 0.4117647 ] ... [0.2117647 0.14509803 0.2784314 ... 0.43137255 0.3529412 0.34117648] [0.2235294 0.11372548 0.2509804 ... 0.4509804 0.39607844 0.2509804 ] [0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432 0.3254902 ]] [[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216 0.45490196] [0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628 0.3882353 ] [0.03921568 0.06666666 0.0862745 ... 0.3764706 0.33333334 0.28235295] ... [0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098 0.09411764] [0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371 0.05490196] [0.12156862 0.11372548 0.17254901 ... 0.2352941 0.17254901 0.14117646]]]] Traceback (most recent call last): File "image_loader.py", line 51, inmain() File "image_loader.py", line 46, in main imshow(images) File "image_loader.py", line 38, in imshow plt.imshow(np.transpose(npimg, (1, 2, 0, 1))) File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose return _wrapfunc(a, 'transpose', axes) File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc return getattr(obj, method)(*args, **kwds) ValueError: repeated axis in transpose
我试图打印出数组以获得尺寸,但是我不知道该怎么做。这很令人困惑。
这是我的直接问题:在使用DataLoader对象中的张量进行训练之前,如何查看输入图像?
首先,dataloader
输出4维张量- [batch, channel, height, width]
。Matplotlib和其他图像处理库经常需要[height, width, channel]
。您使用转置是正确的,只是使用方式不正确。
您的图像很多,images
因此首先您需要选择一个图像(或编写一个for循环以保存所有图像)。这将很简单images[i]
,通常我会使用i=0
。
然后,转置应该将现在的[channel, height, width]
张量转换为一个张量[height, width, channel]
。为此np.transpose(image.numpy(), (1, 2, 0))
,请非常像您一样使用。
放在一起,你应该有
plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))
有时您需要根据用例进行调用.detach()
(将这部分与计算图分开)和.cpu()
(将数据从GPU传输到CPU),具体取决于
plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))