热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

【Pytorch深度学习50篇】·······第二篇:【人脸识别】(5)

【Pytorch深度学习50篇】·······第二篇:【人脸识别】(5)-hello啊朋友们,时隔几日我又回来了,脱更了,因为我去驻厂了,驻厂的意思就是去厂里写代码,在没有网络的环

hello啊朋友们,时隔几日我又回来了,脱更了,因为我去驻厂了,驻厂的意思就是去厂里写代码,在没有网络的环境下,对我来说挑战也不小,对我任何程序员来说没网的话,ctrl+c和ctrl+v这一必杀技就没法用了,所以难顶啊。

3.训练篇

为什么直接就是3了,因为前面已经讲了1和2,不懂就去看,骗流量,哈哈哈。

闲话不多说,开始上训练代码,前面数据准备和网络搭建都已经完成了,现在就要开始训练了

import torch
import torch.nn as nn
import dataset
import my_net as nets
import os


if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_image_folder = r'D:\DATAS\manhua\manhua_tou'
    pre_trian_flag = False
    model_folder = r'D:\DATAS\manhua\models'
    lr = 0.001
    epoches = 100
    batch_size = 16

    # 数据准备
    train_data = dataset.dataset(train_image_folder)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    # 模型初始化
    if pre_trian_flag == True:
        model_path = os.path.join(model_folder, 'best.pth')
        if os.path.exists(model_path):
            net = torch.load(model_path, map_location=device)
            net.train()
            print('加载预训练模型成功')
        else:
            print('未找到模型或预训练模型,开始重新训练')
    else:
        net = nets.My_Net().to(device)  # 在net.py中自定义一个网络
        net.train()

    # 定义损失函数
    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    loss_proess = 1
    print('开始训练')
    for epoch in range(epoches):
        train_loss = 0
        for index, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device)
            output = net(image)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print('epoch[%s/%s]---iteration[%s/%s]---------------loss=%s'%(epoch+1,epoches,index+1,len(train_loader),loss.item()))

        train_loss = train_loss / len(train_loader)
        print('current epoch[%s] total_loss = '%(epoch+1), train_loss)
        if train_loss 

3.1引入的包

还是先说一下导入的包,请看

import torch
import torch.nn as nn
import dataset
import my_net as nets
import os

其中dataset和my_net就是之前提到的两个脚本,其他的就无须多言了

3.2定义的一些路径和超参数

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_image_folder = r'D:\DATAS\manhua\manhua_tou'
    pre_trian_flag = False
    model_folder = r'D:\DATAS\manhua\models'
    lr = 0.001
    epoches = 100
    batch_size = 16

device最后要么是‘cuda’要么是‘cpu’这个就看你电脑的配置了,有没有独立的显卡,一般来说,你搞深度学习没有显卡确实是有点太不方便了,游戏都没法玩,多累啊。哈哈

train_image_folder是训练图片的路径,在dataset篇的时候也讲过了,数据要怎么放,在来截图给你们看看吧

然后每个文件夹里面就是图片文件了,以gangan为例截图演示一下

pre_train_flag 是用来判定有没有预训练模型的标志

model_folder是用来保存模型的文件夹,运行之前,记得先创建一下这个文件夹,免得报错,其实程序里也可以直接加上自己生成这个文件夹的代码,我没加,就是想你报错了回来找我,心机

好了好了,到了最能体现调参侠经验的地方了,这个三个参数,lr,epoches,batch_size,他们是什么,中文名字叫,学习率,迭代次数,批数量。具体都是什么意思,咱们以后开个专题来讲一讲

3.3数据准备和模型定义

    # 数据准备
    train_data = dataset.dataset(train_image_folder)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    # 模型初始化
    if pre_trian_flag == True:
        model_path = os.path.join(model_folder, 'best.pth')
        if os.path.exists(model_path):
            net = torch.load(model_path, map_location=device)
            net.train()
            print('加载预训练模型成功')
        else:
            print('未找到模型或预训练模型,开始重新训练')
    else:
        net = nets.My_Net().to(device)  # 在net.py中自定义一个网络
        net.train()

是不是似曾相识啊,因为我们前面以及写过这个代码了,对不,那我们就不详细说明了

3.4优化器和损失函数的定义

    # 定义损失函数
    criterion = nn.CrossEntropyLoss()

    # 定义优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

一般分类的损失函数都会用到crossentropy,只是它是什么,怎么计算的,这里面涉及到了一些数学姿势,要慢慢来说,这里就先记住就行了

优化器呢一般也就是选择Adam,至于为什么,我到现在也还是半蒙状态,所以先不讲。一般盲选Adam就没错了。

3.5训练

    loss_proess = 1
    print('开始训练')
    for epoch in range(epoches):
        train_loss = 0
        for index, (image, label) in enumerate(train_loader):
            image = image.to(device)
            label = label.to(device)
            output = net(image)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            print('epoch[%s/%s]---iteration[%s/%s]---------------loss=%s'%(epoch+1,epoches,index+1,len(train_loader),loss.item()))

        train_loss = train_loss / len(train_loader)
        print('current epoch[%s] total_loss = '%(epoch+1), train_loss)
        if train_loss 

过程就是,先从train_loader里面读数据,然后将数据都放到gpu上,然后数据送入网络,得到的输出和label做损失,得到loss,然后loss方向传播,用于调整网络里的参数,然后再下一次循环,不断的调整参数,使得loss越来越小,也就是说,输入的数据得到的输出结果就会月接近label,这就是我们要的效果。一切计算的过程都交给了计算机,你只需要等待就好了,nice!!我们看看跑起来是什么效果吧

你看total_loss是不是下降了很多了,这就是训练的魔力。

3.6 test的程序

上代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import PIL.Image as pimg

if __name__ == '__main__':
    classes = ['pangpang','shoushou','gangan','xixi','haha']   #设置成你的类别名称
    model_path = os.path.join(r'D:\DATAS\manhua\models','best.pth')
    test_image_folder = r'D:\DATAS\manhua\test_img'
    image_size = 96
    save_folder = r'D:\DATAS\manhua\output_img'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 测试数据准备
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])


    net = torch.load(model_path).to(device)
    net.eval()

    for i in os.listdir(test_image_folder):
        image_path = os.path.join(test_image_folder,i)
        img_ = pimg.open(image_path)

        img = img_.resize((image_size,image_size))
        img = test_transform(img)
        try:
            img = img.view(1, 3, image_size, image_size).to(device)
        except:
            img = img.view(1, 1, image_size, image_size).to(device)

        pre_prob = net(img)
        pre_class = pre_prob.argmax(1).view(-1)
        print(classes[pre_class.item()])
        save_path = os.path.join(save_folder,classes[pre_class.item()])
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        img_.save(save_path + '/' + i)

 试一试能不能看懂test吧,如果看懂了,就学会了

3.7 整个项目的代码链接

链接:https://pan.baidu.com/s/1bZbDQ1f4bfumjkK3RUUDag 
提取码:wrk4 
截图以示清白

整个深度学习分类任务的代码就到自己了,你看看用你的数据集来试一试效果

另外要谢谢大家的支持,已经400个粉丝了,撒花

至此,敬礼,salute!!!


推荐阅读
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 深入理解Java虚拟机的并发编程与性能优化
    本文主要介绍了Java内存模型与线程的相关概念,探讨了并发编程在服务端应用中的重要性。同时,介绍了Java语言和虚拟机提供的工具,帮助开发人员处理并发方面的问题,提高程序的并发能力和性能优化。文章指出,充分利用计算机处理器的能力和协调线程之间的并发操作是提高服务端程序性能的关键。 ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 本文介绍了使用kotlin实现动画效果的方法,包括上下移动、放大缩小、旋转等功能。通过代码示例演示了如何使用ObjectAnimator和AnimatorSet来实现动画效果,并提供了实现抖动效果的代码。同时还介绍了如何使用translationY和translationX来实现上下和左右移动的效果。最后还提供了一个anim_small.xml文件的代码示例,可以用来实现放大缩小的效果。 ... [详细]
  • Webpack5内置处理图片资源的配置方法
    本文介绍了在Webpack5中处理图片资源的配置方法。在Webpack4中,我们需要使用file-loader和url-loader来处理图片资源,但是在Webpack5中,这两个Loader的功能已经被内置到Webpack中,我们只需要简单配置即可实现图片资源的处理。本文还介绍了一些常用的配置方法,如匹配不同类型的图片文件、设置输出路径等。通过本文的学习,读者可以快速掌握Webpack5处理图片资源的方法。 ... [详细]
  • Python如何调用类里面的方法
    本文介绍了在Python中调用同一个类中的方法需要加上self参数,并且规范写法要求每个函数的第一个参数都为self。同时还介绍了如何调用另一个类中的方法。详细内容请阅读剩余部分。 ... [详细]
  • 怎么在PHP项目中实现一个HTTP断点续传功能发布时间:2021-01-1916:26:06来源:亿速云阅读:96作者:Le ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • Day2列表、字典、集合操作详解
    本文详细介绍了列表、字典、集合的操作方法,包括定义列表、访问列表元素、字符串操作、字典操作、集合操作、文件操作、字符编码与转码等内容。内容详实,适合初学者参考。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
  • MySQL多表数据库操作方法及子查询详解
    本文详细介绍了MySQL数据库的多表操作方法,包括增删改和单表查询,同时还解释了子查询的概念和用法。文章通过示例和步骤说明了如何进行数据的插入、删除和更新操作,以及如何执行单表查询和使用聚合函数进行统计。对于需要对MySQL数据库进行操作的读者来说,本文是一个非常实用的参考资料。 ... [详细]
author-avatar
老男孩2702938107
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有