热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于Paddle的计算机视觉入门教程【学习笔记】(5)Paddlex实现垃圾分类

我是雪天鱼,一名FPGA爱好者,研究方向是FPGA架构探索和数字IC设计。关注公众号【集成电路设计教程】,获取更多学习资料,

我是 雪天鱼,一名FPGA爱好者,研究方向是FPGA架构探索和数字IC设计。

关注公众号【集成电路设计教程】,获取更多学习资料,并拉你进“IC设计交流群”。
QQIC设计&FPGA&DL交流群 群号:866169462

原作者课程链接:https://www.bilibili.com/video/BV18b4y1J7a6?p=2

一、安装 PaddleX

首先需要安装 Microsoft Visual C++ 14.0,安装工具自取
链接:https://pan.baidu.com/s/1NfJF2x3ZEnGzvlEcHcjxxg
提取码:n0ll

下载好后解压,双击 exe进行安装

enter description here

再进入之前创建的 paddle 的虚拟环境,输入下述指令安装PaddleX:

pip install paddlex==2.1.0 -i https://mirror.baidu.com/pypi/simple

通过 pip list 查看是否安装成功

enter description here

二、准备数据集

链接:https://pan.baidu.com/s/12x6C6O0R_3xGjTHdGy8Syw
提取码:l5ft

下载好后解压即可。

enter description here

  • eval.txt 用于测试
  • train.txt 用于模型训练
  • labels.txt 每个文件夹对于的label

三、模型训练

打开pycharm,新建名为 paddle 的项目,设置好 Python 解释器为 paddle 虚拟环境中的python。然后新建 train.py 编写训练代码:

from paddlex import transforms as T
import paddlex as pdx# 1 定义训练集的数据增强算子
train_transforms = T.Compose([T.RandomCrop(crop_size=224),T.RandomHorizontalFlip(),T.Normalize()])# 2 定义测试集的数据增强算子
eval_transforms = T.Compose([T.ResizeByShort(short_size=256),T.CenterCrop(crop_size=224),T.Normalize()
])# 3 定义训练数据集和测试数据集
train_dataset = pdx.datasets.ImageNet(data_dir='rubbish',file_list='rubbish/train.txt',label_list='rubbish/labels.txt',transforms=train_transforms,shuffle=True)
eval_dataset = pdx.datasets.ImageNet(data_dir='rubbish',file_list='rubbish/eval.txt',label_list='rubbish/labels.txt',transforms=eval_transforms)num_classes = len(train_dataset.labels) # label 个数# 4 定义垃圾分类的CNN模型
model = pdx.cls.MobileNetV3_small(num_classes=num_classes)# 5 模型训练
model.train(num_epochs=10, # 训练总轮数train_dataset=train_dataset,train_batch_size=64, # 每次送入显卡进行训练计算的图片张数eval_dataset=eval_dataset,lr_decay_epochs=[4, 6, 8], # 学习率减小的轮数save_dir='output/mobilenetv3_small', # 输出文件夹设置use_vdl=True)

刚开始训练时,学习率大容易快速收敛,但多轮训练后就需要将其设置小点,以免错过 loss 最小值。

enter description here

可以看到训练出来最优的模型是第 9 轮所训练出来的模型。

生成的文件有:

enter description here

enter description here

  • model.yml 是模型标注文件,简单介绍该模型。

四、使用训练好的模型进行预测

模型既然训练好了,就可以使用了,新建一个 predict.py,输入:

import paddlex as pdx
model = pdx.load_model('output/mobilenetv3_small/best_model')
result = model.predict('1.jpg')
print("Predict Result: ", result)

这里我从网上随便找了个易拉罐的图片:

enter description here

输出结果:

enter description here

可以看到识别正确。

五、训练过程可视化

通过在对应 paddle 的虚拟环境下,输入visualdl --logdir output/mobilenetv3_small --port 8001 即可可视化查看训练过程中,学习率、loss等参数的变化情况。

enter description here


推荐阅读
author-avatar
heqiuhao
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有