热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Pytorch学习(六)

利用PyTorch完成Cifar10图像分类分类网络的基本结构先加载数据,然后将这个数据进行重组,组合成我们神经网络所需要的数据的形式(

利用PyTorch完成Cifar10图像分类


分类网络的基本结构

在这里插入图片描述

  • 先加载数据,然后将这个数据进行重组,组合成我们神经网络所需要的数据的形式(数据预处理/数据增强)。
  • 我们将数据丢到一个CNN网络中进行特征的提取。在特征提取完之后我们会得到一个N维的向量,而这N个向量就表示我们想要分的N个类别。
  • 通过一个损失函数来计算网络当前的损失,通过这个损失对网络进行反向传播,从而对参数进行调整。在进行BP的时候我们就需要定义我们网络的优化器,如梯度下降法等来完成对网络参数的迭代。迭代到模型收敛的时候,即loss很小或是loss几乎不变的时候,参数就训练完成了。
  • 有了这些参数,就可以构建一个网络。从而对我们接下来输入数据的一个推理。
  • 训练的过程是求解网络参数的过程
  • 推理的过程是已知网络结构,并且经过训练过程我们已经拿到网络参数,这个时候我们就构建出了前向推理的函数。利用这个函数,给定其输入数据x,就可以计算出预测的结果。

数据加载模块

在这里插入图片描述

  • 在pytorch中channel是优先排列的,我们需要将数据转换成 channel * h * w的数据

数据增强


  • 是一种将已有的数据进行一种扩充的手段,为了解决数据集不足的问题
  • 常见的一些数据增强的方法:进行一些随机的翻转/旋转 / 修改图像的亮度 饱和度 对比度等
  • 我们使用的是pytorch下的transformer类
    在这里插入图片描述

网络结构

在这里插入图片描述

类别概率分布


  • N维度向量对应N个类别
  • 采用FC层(对图像尺寸敏感),将其拉成一个向量。也可以使用卷积层(对图像尺寸敏感)或者是pooling层(pooling层无参数,所以对shape不敏感)
  • 通过一个softmax函数,将这N维向量映射到概率分布上面去 Si=ei/∑jejS_i = e^i / \sum_j e^jSi=ei/jej

损失


  • 使用交叉熵损失 nn.CrossEntropyLoss
  • 在分类问题中需要定义label。从标签转化为向量的是One-hot编码
  • [1,0,0]第一类 [0,1,0] 第二类 但是这种方式太硬了,所以需要label smoothing在这里插入图片描述

分类问题常用指标

在这里插入图片描述

  • PR曲线 ROC曲线 AUC曲线
    在这里插入图片描述

优化器


  • 推荐使用Adam,学习率衰减采用指数衰减
    在这里插入图片描述

cifar10数据介绍——读取——处理


cifar10数据介绍


  • cifar10是完成一个10分类,cifar100是完成一个100分类
  • cifar10数据集中一共有六万张图片,其中五万张是训练集(每个类别有5000张图片),一万张是测试集。图片大小是32*32
  • 数据集获取地址

cifar10数据的读取与保存

import os
import cv2
import numpy as np
import pickle def unpickle(file):with open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dictlabel_name = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]import glob
train_list = glob.glob("cifar-10-batches/data_batch_*")
test_list = glob.glob("cifar-10-batches/test_batch")
save_path = "cifar-10-batches/train"
test_path = "cifar-10-batches/test"
# for l in train_list:
# l_dict = unpickle(l)
# #print( l_dict.keys() )
# for im_idx, im_data in enumerate(l_dict[b'data']):
# #print(im_idx)
# #print(im_data) #这里显示的是一个向量,所以我们需要reshape一下 3*32*32# im_label = l_dict[b'labels'][im_idx]
# im_name = l_dict[b'filenames'][im_idx] # #print(im_label, im_name, im_data)
# im_label_name = label_name[im_label]
# im_data = np.reshape(im_data, [3,32,32])
# im_data = np.transpose(im_data, (1,2,0))# #cv2.imshow("im_data", cv2.resize(im_data, (200,200) ) )
# #cv2.waitKey(0)# if not os.path.exists("{}/{}".format(save_path, im_label_name)):
# os.mkdir("{}/{}".format(save_path, im_label_name))# cv2.imwrite("{}/{}/{}".format(save_path, im_label_name, im_name.decode("utf-8")), im_data)for l in test_list:l_dict = unpickle(l)#print( l_dict.keys() )for im_idx, im_data in enumerate(l_dict[b'data']):#print(im_idx)#print(im_data) #这里显示的是一个向量,所以我们需要reshape一下 3*32*32im_label = l_dict[b'labels'][im_idx]im_name = l_dict[b'filenames'][im_idx] #print(im_label, im_name, im_data)im_label_name = label_name[im_label]im_data = np.reshape(im_data, [3,32,32])im_data = np.transpose(im_data, (1,2,0))#cv2.imshow("im_data", cv2.resize(im_data, (200,200) ) )#cv2.waitKey(0)if not os.path.exists("{}/{}".format(test_path, im_label_name)):os.mkdir("{}/{}".format(test_path, im_label_name))cv2.imwrite("{}/{}/{}".format(test_path, im_label_name, im_name.decode("utf-8")), im_data)

cifar10数据的加载与处理

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import globlabel_name = ["airplane", "automobile", "bird", "cat","deer","dog","frog","horse","ship","truck"]label_dict = {}for idx, name in enumerate(label_name):label_dict[name] = idxdef default_loader(path):return Image.open(path).covert("RGB")train_transform = transforms.Compose([transforms.RandomResizedCrop((28,28)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(90),transforms.RandomGrayscale(0.1),transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),transforms.ToTensor()
])class MyDataset(Dataset):def __init__(self, im_list, transform = None, loader = default_loader):super(MyDataset,self).__init__()imgs = []for im_item in im_list:im_label_name = im_item.split("/")[-2]imgs.append([im_item, label_dict[im_label_name]])self.imgs = imgsself.transform = transformself.loader = loader# 定义对数据的读取和对数据的增强,然后返回图片的数据和labeldef __getitem__(self, index):im_path, im_label = self.imgs[index]im_data = self.loader(im_path)if self.transform is not None:im_data = self.transform(im_data)return im_data, im_labeldef __len__(self):return len(self.imgs)im_train_list = glob.glob("cifar-10-batches/train/*/*.png")
im_test_list = glob.glob("cifar-10-batches/test/*/*.png")train_dataset = MyDataset( im_train_list, transform = train_transform )
test_dataset = MyDataset( im_test_list, transform = transforms.ToTensor() )train_data_loader = DataLoader(dataset = train_dataset, batch_size = 66, shuffle = True, num_workers = 4)
test_data_loader = DataLoader(dataset = test_dataset, batch_size = 66, shuffle = False, num_workers = 4)print("num_of_train", len(train_dataset))
print("num_of_test", len(test_dataset))

使用VGGNet实现cifar10分类

网络定义的程序

import torch
import torch.nn as nn
import torch.nn.functional as Fclass VGGbase(nn.Module):def __init__(self):super(VGGbase, self).__init__()self.conv1 = nn.Sequrntial(nn.Conv2d(3, 64, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(64),nn.ReLU())self.max_pooling1 = nn.MaxPool2d(kernel_size = 2, stride = 2)self.conv2_1 = nn.Sequrntial(nn.Conv2d(64, 128, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(128),nn.ReLU())self.conv2_2 = nn.Sequrntial(nn.Conv2d(128, 128, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(128),nn.ReLU())self.max_pooling2 = nn.MaxPool2d(kernel_size = 2, stride = 2)self.conv3_1 = nn.Sequrntial(nn.Conv2d(128, 256, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(256),nn.ReLU())self.conv3_2 = nn.Sequrntial(nn.Conv2d(256, 256, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(256),nn.ReLU())self.max_pooling3 = nn.MaxPool2d(kernel_size = 2, stride = 2, padding=1)self.conv4_1 = nn.Sequrntial(nn.Conv2d(256, 512, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(256),nn.ReLU())self.conv4_2 = nn.Sequrntial(nn.Conv2d(512, 512, kernal_size = 3, stride = 1, padding = 1),nn.BatchNorm2d(512),nn.ReLU())self.max_pooling3 = nn.MaxPool2d(kernel_size = 2, stride = 2)#batchsize * 512 * 2 * 2 -- > batchsize * (512 * 4) self.fc = nn.Linear( 512 * 4, 10 )def forward(self, x):batchsize = x.size(0)out = self.conv1(x)out = self.max_pooling1(out)out = self.conv2_1(out)out = self.conv2_2(out)out = self.max_pooling2(out)out = self.conv3_1(out)out = self.conv3_2(out)out = self.max_pooling3(out)out = self.conv4_1(out)out = self.conv4_2(out)out = self.max_pooling4(out)out = out.view(batchsize, -1)out = self.fc(out)out = F.log_sofemax(out, dim=1)return out
def VGGNet():return VGGbase()

训练代码

import torch
import torch.nn as nn
import torchvision
from vggnet import VGGNet
from load_cifar10 import train_loader, test_loader
import os
import tensorboardX #学着使用tensorboarddevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")epoch_num = 200
lr = 0.01
batch_size = 128net = VGGNet().to(device)#lossloss_func = nn.CrossEntropyLoss()#optimizeroptimizer = torch.optim.Adam( net.parameters(), lr = lr )scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)if not os.path.exists("log"):os.mkdir("log")
writer = tensorboardX.SummaryWriter("log")step_n = 0for epoch in range(epoch_num):print("epoch is", epoch)net.train() #train BN dropoutfor i, data in enumerate(train_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = net(inputs)loss = loss_func(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()_, pred = torch.max(outputs.data, dim=1)correct = pred.eq(labels.data).cpu().sum()print("step", i, "loss is:", loss.item(),"mini-batch correct is:", 100.0 * correct / batch_size)writer.add_scalar("train loss", loss.item(),global_step = step_n)writer.add_scalar("train correct", 100.0 * correct.item() / batch_size,global_step = step_n)step_n += 1if not os.path.exists("models"):os.mkdir("models")torch.save(net.state_dict(), "models/{}.pth".format(epoch + 1)) secheduler.step()print("lr is ", optimizer.state_dict()["param_groups"][0]["lr"])#测试脚本sum_loss = 0sum_correct = 0for i, data in enumerate(test_loader):net.eval()inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = net(inputs)loss = loss_func(outputs, labels)_, pred = torch.max(outputs.data, dim=1)correct = pred.eq(labels.data).cpu().sum()sum_loss += loss.item()sum_correct += correct.item()writer.add_scalar("test loss", loss.item(),global_step = step_n)writer.add_scalar("test correct", 100.0 * correct.item() / batch_size,global_step = step_n)#计算每个batch平均的loss和correcttest_loss = sum_loss * 1.0 / len(test_loader)test_correct = sum_correct * 100.0 / len(test_loader) / batch_sizeprint("epoch is", epoch + 1, "loss is:", test_loss,"test correct is:", test_correct)writer.close()


推荐阅读
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了如何使用n3-charts绘制以日期为x轴的数据,并提供了相应的代码示例。通过设置x轴的类型为日期,可以实现对日期数据的正确显示和处理。同时,还介绍了如何设置y轴的类型和其他相关参数。通过本文的学习,读者可以掌握使用n3-charts绘制日期数据的方法。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • Html5-Canvas实现简易的抽奖转盘效果
    本文介绍了如何使用Html5和Canvas标签来实现简易的抽奖转盘效果,同时使用了jQueryRotate.js旋转插件。文章中给出了主要的html和css代码,并展示了实现的基本效果。 ... [详细]
  • WhenIusepythontoapplythepymysqlmoduletoaddafieldtoatableinthemysqldatabase,itdo ... [详细]
  • 突破MIUI14限制,自定义胶囊图标、大图标样式,支持任意APP
    本文介绍了如何突破MIUI14的限制,实现自定义胶囊图标和大图标样式,并支持任意APP。需要一定的动手能力和主题设计师账号权限或者会主题pojie。详细步骤包括应用包名获取、素材制作和封包获取等。 ... [详细]
author-avatar
海哭的声音2602928847
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有