热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

利用PyTorch快速实现分类任务

关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P

关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习PyTorch,找了很多自定义数据加载的方法,还是使用torch中封装的库函数好用,而且快捷,会根据路径自动返回对应的标签,下面的代码每一行都给了注释。import torchfrom torchvision import transforms, utilsfrom torchvision import datasetsimport torch.utils.dataimport matplotlib.pyplot as plt# 定义图像预处理transform1 = tranhttps://blog.csdn.net/weixin_55737425/article/details/122958584

这里给出一个模板,适合想要快速实现的朋友们(想要快速做出效果),不需要多少理论知识,只需要将文中的文件地址更改为自己的电脑上的地址即可。(注意图片的保存方式有一定的格式,详细可以查阅ImageFolder函数的用法)

此处每一行的代码都已经标记缘由和作用,如果还有疑惑,欢迎垂询问题!

import random
from torch.utils.data import DataLoader
from torchvision.models import resnet50
from imutils import paths
import torch.nn as nn
from torch import optim
import numpy
from torchvision import transforms, utils
import torch
from torchvision import datasets
import matplotlib.pyplot as pltdef load_data():transform1 = transforms.Compose([ # 这里最好加上一个中括号,否则会被认为是意外实参transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转,概率为0.3transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转,概率为0.3# transforms.CenterCrop((400, 400)),transforms.ToTensor(), # 转换成Tensor类型transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.255)) # 这里是为了和官方文档保持一致])batch_size = 8train_data = datasets.ImageFolder(r"C:\Users\asus\Desktop\cnn_data\cnn_data\data\training_data", transform=transform1)# print(train_data.imgs)# 加载数据train_data = DataLoader(train_data, batch_size=batch_size, shuffle=True)return train_data# def im_convert(tensor): # 这里可以不用理睬,我是想要显示原来图片的
# image = tensor
# image = image.numpy().squeeze()
# image = image.transpose(1, 2, 0)
# iamge = image*numpy.array(0.229, 0.224, 0.255) + numpy.array(0.485, 0.456, 0.406)
# image = image.clip(0, 1)
# return imagedef train(train_data):lr = 0.0001EPOCH = 12 # 可以自己调整,多一点会更好,但十分耗时间device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 利用GPU进行训练model = resnet50(pretrained=True).to(device) # 此处使用迁移学习的方法预加载权重,此处会下载一段时间model.train() # 设置运行模式in_channel = model.fc.in_features # 获取全连接层中输入的维数model.fc = nn.Linear(in_channel, 2) # 重新赋值全连接层criterion = nn.CrossEntropyLoss().to(device) # 分类问题使用交叉熵的方法optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) # 也可以使用Adam,效果也好,momentum根据文献资料,0.9为最优选择scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) # 每经历2个epoch就衰减十分之一,也可以自己选择running_loss = 0for epoch in range(0, EPOCH):correct = 0for i, (data, target) in enumerate(train_data, 1):data = torch.autograd.Variable(data).to(device)target = torch.autograd.Variable(target).to(device)optimizer.zero_grad() # 清空上一次的梯度值output = model(data)loss = criterion(output, target)running_loss = loss.item()loss.backward()optimizer.step()prediction = torch.argmax(output, dim=1) # 返回维度为dim上最大值的索引correct += (prediction == target).sum().item() # 当prediction==target时会返回“1”,predicton和target在此处都是tensor类型,所以返回的是“tensor(1)”,之后通过item返回数值if i % 2 == 0:print("第{}个EPOCH,第{}个batch,当前损失为{}".format(epoch+1, i, running_loss))print("本轮训练的准确率为{:}".format(correct/len(train_data)))
if __name__ == '__main__':train_data = load_data()train(train_data)


推荐阅读
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 图像因存在错误而无法显示 ... [详细]
  • 小程序wxs中的时间格式化以及格式化时间和date时间互转
    本文介绍了在小程序wxs中进行时间格式化操作的问题,并提供了解决方法。同时还介绍了格式化时间和date时间的互相转换的方法。 ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • VScode格式化文档换行或不换行的设置方法
    本文介绍了在VScode中设置格式化文档换行或不换行的方法,包括使用插件和修改settings.json文件的内容。详细步骤为:找到settings.json文件,将其中的代码替换为指定的代码。 ... [详细]
  • Nginx使用(server参数配置)
    本文介绍了Nginx的使用,重点讲解了server参数配置,包括端口号、主机名、根目录等内容。同时,还介绍了Nginx的反向代理功能。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 本文介绍了在Vue项目中如何结合Element UI解决连续上传多张图片及图片编辑的问题。作者强调了在编码前要明确需求和所需要的结果,并详细描述了自己的代码实现过程。 ... [详细]
  • Go GUIlxn/walk 学习3.菜单栏和工具栏的具体实现
    本文介绍了使用Go语言的GUI库lxn/walk实现菜单栏和工具栏的具体方法,包括消息窗口的产生、文件放置动作响应和提示框的应用。部分代码来自上一篇博客和lxn/walk官方示例。文章提供了学习GUI开发的实际案例和代码示例。 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
author-avatar
漫湾镇团委
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有