import os
import cv2
import numpy as np
import pickle


def unpickle(file):
    """Load one CIFAR-10 batch file and return its dict (bytes-keyed).

    The official batches were pickled under Python 2, so encoding='bytes'
    keeps keys/values as bytes — callers index with b'data', b'labels', etc.
    """
    with open(file, 'rb') as fo:
        # renamed from `dict` — the original shadowed the builtin
        batch = pickle.load(fo, encoding='bytes')
    return batch


# Index -> human-readable class name, in the official CIFAR-10 label order.
label_name = ["airplane", "automobile", "bird", "cat", "deer",
              "dog", "frog", "horse", "ship", "truck"]

import glob
# Binary batch files shipped inside the "CIFAR-10 python version" archive.
train_list = glob.glob("cifar-10-batches/data_batch_*")
test_list = glob.glob("cifar-10-batches/test_batch")
# Output roots for the decoded images, one sub-folder per class name.
save_path = "cifar-10-batches/train"
test_path = "cifar-10-batches/test"
def _extract_batches(batch_list, out_root):
    """Decode every image in *batch_list* and save it under out_root/<class>/<name>.

    Each pickled row is a flat 3072-byte vector (3 channels x 32 x 32); it is
    reshaped to CHW and transposed to HWC before writing with OpenCV.
    """
    for batch_file in batch_list:
        batch = unpickle(batch_file)
        for im_idx, im_data in enumerate(batch[b'data']):
            im_label = batch[b'labels'][im_idx]
            im_name = batch[b'filenames'][im_idx]
            im_label_name = label_name[im_label]

            im_data = np.reshape(im_data, [3, 32, 32])
            im_data = np.transpose(im_data, (1, 2, 0))

            # NOTE(review): CIFAR rows are RGB but cv2.imwrite expects BGR, so
            # the saved files have R and B swapped.  Harmless as long as train
            # and test are both written and read back the same way — confirm
            # before mixing with externally-sourced images.

            out_dir = "{}/{}".format(out_root, im_label_name)
            # makedirs(exist_ok=True) also creates out_root itself; the original
            # os.mkdir raised FileNotFoundError when the parent was missing.
            os.makedirs(out_dir, exist_ok=True)
            cv2.imwrite("{}/{}".format(out_dir, im_name.decode("utf-8")),
                        im_data)


# The training split was already exported once (original code left commented);
# uncomment to regenerate it.
# _extract_batches(train_list, save_path)
_extract_batches(test_list, test_path)
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob

# Class names in the official CIFAR-10 order; label_dict maps name -> index.
label_name = ["airplane", "automobile", "bird", "cat", "deer",
              "dog", "frog", "horse", "ship", "truck"]
label_dict = {name: idx for idx, name in enumerate(label_name)}


def default_loader(path):
    """Read an image from disk and force it to 3-channel RGB.

    BUG FIX: the original called Image.open(path).covert("RGB") — a typo
    that raises AttributeError on the first sample loaded.
    """
    return Image.open(path).convert("RGB")


# Aggressive augmentation pipeline for the training split.
# NOTE(review): RandomVerticalFlip and 90-degree rotations are unusual for
# CIFAR-10 (an upside-down truck is not a natural sample) — kept because the
# original pipeline used them; confirm this is intentional.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor(),
])


class MyDataset(Dataset):
    """Folder-layout dataset: cifar-10-batches/<split>/<class_name>/<file>.png.

    Reads and augments one image per __getitem__ call and returns
    (image_tensor, int_label).
    """

    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # The parent directory name is the class label.
            # NOTE(review): split("/") assumes POSIX separators — use
            # os.path machinery if this must run on Windows.
            im_label_name = im_item.split("/")[-2]
            imgs.append([im_item, label_dict[im_label_name]])
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Load, optionally augment, and return (im_data, im_label)."""
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        return len(self.imgs)


im_train_list = glob.glob("cifar-10-batches/train/*/*.png")
im_test_list = glob.glob("cifar-10-batches/test/*/*.png")

train_dataset = MyDataset(im_train_list, transform=train_transform)
test_dataset = MyDataset(im_test_list, transform=transforms.ToTensor())

train_data_loader = DataLoader(dataset=train_dataset, batch_size=66,
                               shuffle=True, num_workers=4)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=66,
                              shuffle=False, num_workers=4)

print("num_of_train", len(train_dataset))
print("num_of_test", len(test_dataset))
# ----- Network definition script (vggnet.py) -----
import torch
import torch.nn as nn
import torch.nn.functional as F


class VGGbase(nn.Module):
    """Small VGG-style CNN for 28x28 CIFAR-10 crops, 10 output classes.

    forward() returns per-class LOG-probabilities (log_softmax), so the
    matching training criterion is nn.NLLLoss.
    """

    def __init__(self):
        super(VGGbase, self).__init__()
        # BUG FIXES vs. original: nn.Sequrntial -> nn.Sequential,
        # kernal_size -> kernel_size, conv4_1's BatchNorm2d(256) -> 512
        # (its conv outputs 512 channels), and the second duplicate
        # self.max_pooling3 -> self.max_pooling4 (forward referenced an
        # undefined max_pooling4).

        # 3 x 28 x 28 -> 64 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 64 x 14 x 14 -> 128 x 7 x 7
        self.conv2_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU())
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU())
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 128 x 7 x 7 -> 256 x 4 x 4 (padding=1 rounds 7 -> 4)
        self.conv3_1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)

        # 256 x 4 x 4 -> 512 x 2 x 2
        self.conv4_1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU())
        self.conv4_2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU())
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # batchsize x 512 x 2 x 2 --> batchsize x (512 * 4)
        self.fc = nn.Linear(512 * 4, 10)

    def forward(self, x):
        """x: (batch, 3, 28, 28) float tensor -> (batch, 10) log-probs."""
        batchsize = x.size(0)
        out = self.conv1(x)
        out = self.max_pooling1(out)

        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)

        out = self.conv3_1(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)

        out = self.conv4_1(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)

        out = out.view(batchsize, -1)
        out = self.fc(out)
        # BUG FIX: F.log_sofemax -> F.log_softmax
        out = F.log_softmax(out, dim=1)
        return out


def VGGNet():
    """Factory used by the training script."""
    return VGGbase()
# ----- Training script (train.py) -----
import torch
import torch.nn as nn
import torchvision
from vggnet import VGGNet
from load_cifar10 import train_loader, test_loader
import os
import tensorboardX  # TensorBoard logging for PyTorch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epoch_num = 200
lr = 0.01
batch_size = 128

net = VGGNet().to(device)

# The network's forward() already returns log-probabilities (log_softmax),
# so NLLLoss is the matching criterion.  The original CrossEntropyLoss would
# have applied a second log_softmax on top of the model's own.
loss_func = nn.NLLLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# Multiply the learning rate by 0.9 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

if not os.path.exists("log"):
    os.mkdir("log")
writer = tensorboardX.SummaryWriter("log")

step_n = 0
for epoch in range(epoch_num):
    print("epoch is", epoch)
    net.train()  # enable BatchNorm running-stat updates / dropout

    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, pred = torch.max(outputs.data, dim=1)
        correct = pred.eq(labels.data).cpu().sum()
        # BUG FIX: divide by the actual mini-batch size, not the hard-coded
        # batch_size=128 — the data loaders use batch_size=66 and the final
        # batch may be smaller still.
        print("step", i, "loss is:", loss.item(),
              "mini-batch correct is:", 100.0 * correct / labels.size(0))
        writer.add_scalar("train loss", loss.item(), global_step=step_n)
        writer.add_scalar("train correct",
                          100.0 * correct.item() / labels.size(0),
                          global_step=step_n)
        step_n += 1

    if not os.path.exists("models"):
        os.mkdir("models")
    torch.save(net.state_dict(), "models/{}.pth".format(epoch + 1))
    scheduler.step()  # BUG FIX: was misspelled "secheduler" (NameError)
    print("lr is ", optimizer.state_dict()["param_groups"][0]["lr"])

    # ---- evaluation on the held-out test split ----
    sum_loss = 0
    sum_correct = 0
    sum_samples = 0
    net.eval()  # hoisted out of the loop; one call per epoch is enough
    with torch.no_grad():  # no gradients needed during evaluation
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).cpu().sum()

            sum_loss += loss.item()
            sum_correct += correct.item()
            sum_samples += labels.size(0)
            writer.add_scalar("test loss", loss.item(), global_step=step_n)
            writer.add_scalar("test correct",
                              100.0 * correct.item() / labels.size(0),
                              global_step=step_n)

    # Average per-batch loss, and accuracy over the samples actually seen
    # (the original divided by len(test_loader) * 128 although the loader
    # batch size is 66, deflating the reported accuracy).
    test_loss = sum_loss * 1.0 / len(test_loader)
    test_correct = sum_correct * 100.0 / max(sum_samples, 1)
    print("epoch is", epoch + 1, "loss is:", test_loss,
          "test correct is:", test_correct)

writer.close()