热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

YOLOV4Pytorch版本训练自建数据集和预测

1.程序下载本文程序核心部分完全参考开源代码:https:github.comWongKinYiuPyTorch_YOLOv4。只是从一种学习的角度去写了我的代码仓库,在基础上增加




1. 程序下载

本文程序核心部分完全参考开源代码:https://github.com/WongKinYiu/PyTorch_YOLOv4 。
只是从一种学习的角度去写了我的代码仓库,在基础上增加了一些常用的脚本(会持续更新)。
我的仓库地址为:https://github.com/hx-0614/yolov4-pytorch-learing

git clone https://github.com/hx-0614/yolov4-pytorch-learing.git

其中我新增的脚本文件,会在README.md文件中说明。代码结构如下(红色框内为我想对于原作者代码更改或者新增的脚本文件,核心代码没有更改)。
在这里插入图片描述


2. 代码结构解析

如上图所示代码结构。

|cfg # 存储模型结构配置文件的地方。
|------|yolov4.cfg
|------|......
|data # 数据读取文件参数。
|------|hyp.scratch.yaml # 训练相关的参数,我一般使用默认。
|------|coco.data # 训练集,验证集,数据集类别,数据集标签类别参数文件。
|------|coco.names # 数据集标签类别。这个文件一般测试的时候使用,貌似train没用到。
|------|......
|images # 存储模型结构的图片。
|------|......
|models # 构建模型的脚本文件。
|------|......
|utils # 程序执行过程中要用到的其他相关函数的脚本文件。
|------|......
|weights # 存放模型的文件,我是把预先训练模型存到了这里。
|------|......
|detect.py # 模型自带的批量推理代码,默认是输入一个图片文件夹,然后输出检测带有框的文件夹。
|test.py # 测试mAP的脚本文件。其中train.py 也会调用它,也可以单独使用,主要是针对coco.data里面的测试集路径进行测试。
|train.py # 训练脚本文件。
### 我新建的一些脚本 ###
|predict_image.py # 预测一张图片的脚本。
|predict_video.py # 预测一段视频的脚本。
|run_train.sh # 运行train.py的shell脚本。
### README.md ###
|README_V4.md # 原作者的README.md文件。
|README.md. # 我的README.md文件。
|requirements.txt # 需要的相关环境下载文件。


3. 训练自己的数据集


3.1. 首先需要准备自己的数据集

这里假设已经有每张图片对应的边框标签文件
大多数我们可能拿到的是VOC格式的数据集,先将其整理成下图所示的格式(其实不麻烦,有annotations和images就可以其他两个文件需要自己新建)图片来自于公开数据集Seaships
在这里插入图片描述

然后通过脚本制作YOLO-V4 Pytorch版本可以读取的格式(这个格式通用于u版的yolo系列,包括现在比赛常用的YOLO-V5)。数据集制作脚本百度搜索很多,我的代码也是搜索自百度,所以没有放到代码仓库里面,这里就在下面直接附上代码(…/代表自己数据集的路径)。

# makeTxt.py 制作训练集和测试集列表的脚本。
# 需要将脚本中的所有路径进行更改,手动改就好,改成自己数据集存储的地方,我建议大家使用绝对路径(不要出现中文)。
import os
import random
trainval_percent = 0.2 # 可自行进行调节
train_percent = 1
xmlfilepath = '../annotations'
txtsavepath = '../images'
total_xml = os.listdir(xmlfilepath)
num = len(total_xml)
list = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list, tv)
train = random.sample(trainval, tr)
# ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('../ImageSets/test.txt', 'w')
ftrain = open('../ImageSets/train.txt', 'w')
# fval = open('ImageSets/Main/val.txt', 'w')
for i in list:
name = total_xml[i][:-4] + '\n'
if i in trainval:
# ftrainval.write(name)
if i in train:
ftest.write(name)
# else:
# fval.write(name)
else:
ftrain.write(name)
# ftrainval.close()
ftrain.close()
# fval.close()
ftest.close()

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# voc_label.py 生成YOLO可以读取的数据集格式,最后会在labels文件夹下面生成每个图片的标签文件,txt结尾。
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
sets = ['train', 'test']
classes = ['a', 'b', 'c'] # 自己训练的类别
def convert(size, box):
dw = 1. / size[0]
dh = 1. / size[1]
x = (box[0] + box[1]) / 2.0
y = (box[2] + box[3]) / 2.0
w = box[1] - box[0]
h = box[3] - box[2]
x = x * dw
w = w * dw
y = y * dh
h = h * dh
return (x, y, w, h)
def convert_annotation(image_id):
in_file = open('../annotations/%s.xml' % (image_id))
out_file = open('../labels/%s.txt' % (image_id), 'w')
tree = ET.parse(in_file)
root = tree.getroot()
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text)
for obj in root.iter('object'):
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult) == 1:
continue
cls_id = classes.index(cls)
xmlbox = obj.find('bndbox')
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
float(xmlbox.find('ymax').text))
print(root.find('filename').text)
bb = convert((w, h), b)
out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
wd = getcwd()
for image_set in sets:
if not os.path.exists('../labels/'):
os.makedirs('../labels/')
image_ids = open('../ImageSets/%s.txt' % (image_set)).read().strip().split()
list_file = open('../%s.txt' % (image_set), 'w')
for image_id in image_ids:
list_file.write('../images/%s.jpg\n' % (image_id))
convert_annotation(image_id)
list_file.close()

最后会在路径"…/"下面生成两个文件 train.txt, test.txt(路径写成…/train.txt, …/test.txt,后续需要根据自己的路径更改,建议写绝对路径),然后labels下面会生成每个图片对应的标签文件txt结尾。



3.2. 修改相关的配置文件

(1)修改data文件夹下面的配置文件。主要修改两个文件,也可以新建两个文件一个是 .data结尾一个是 .names结尾的文件,如下:

# my.data
# train和test的路径根据自己的生成情况更改,但是train.txt里面存储的是路径。不是图片的名字。
train: ../train.txt
val: ../test.txt
test: ../test.txt
# number of classes
nc: 3
# class names
names: ['a', 'b', 'c']

# my.names
# 这个文件存储类别,写的时候把这两行备注去掉以防出bug。
a
b
c

(2)修改cfg文件,模型结构文件,以yolov4.cfg为例子。
主要修改两个地方,yolo层和yolo层的上一层。如下图的红色框所示。
比如我们的类别为[‘a’, ‘b’, ‘c’] 所以yolo层里面的 classes改成3。yolo层上一层的filters改成(3 + 5)* 3,这个公式和yolo的输出维度有关系为 (类别数 + 5)* 3。每个cfg文件可能有多个yolo层,标准的YOLO-V4有3个yolo层需要更改三次。
![image.png](https://www.icode9.com/i/ll/?i=img_convert/c7453c98882c8dc4029ab349dee20786.png#clientId=u6d814a6e-7a23-4&from=paste&prism language-shell">pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
bash run_train.sh

注意:requirements.txt是按照原代码所写的,我这里是**习惯使用torch1.7.1,torchvision0.8.2版本**相对来说没有什么错误。如果遇到有些包没有,其实再pip就好了,我习惯于加上清华镜像安装环境。
在训练可能会遇到ModuleNotFoundError: No module named ‘mish_cuda’ ,需要根据bug提示把所有的from mish_cuda import MishCuda as Mish替换就好(原作者代码写了安装mish_cuda,但是我选择了直接替换了代码)。

class Mish(nn.Module):
def __init__(self):
super().__init__()
def forward(self,x):
x = x * (torch.tanh(F.softplus(x)))
return x


4. 预测图片和视频


4.1. predict_image.py

这里面的代码主要是方便大家进行测试一张图片,或者后续大家部署的时候可以使用此接口,在代码中写明了,每个函数的作用。这里面需要把模型读取,包括模型读取的尺寸,图片路径进行更改,图片输入尺寸等进行更改,后续会进行更新,写成更加简洁化的接口形式,也希望大家对我的代码提出意见。

if __name__ == "__main__":
"""
letterbox() 图片resize后空白区域填充函数
non_max_suppression() nms
scale_coords_() 图片尺寸还原
read_img() 读取单张图片,返回归一化后的张量 和 原图
init() 初始化模型
process() 主运行函数。输入参数为model,img 返回每张图得到的bbox结果 json格式
"""
img, img0 = read_img(path="data/samples/0056.jpg")
model = init()
res = process(model, img, img0)


4.2. predict_video.py

这是读取一个视频的代码,本质是调用了predict_image.py里面的预测接口进行。在predict_image.py里面的detect()函数是predict_video.py调用的接口,比较好理解,我就是在predict_video.py里面写好了读取视频每一帧的方法所以在调用detect()的时候不需要再根据图片路径读取图片,直接对读取好的图片变量进行后处理。




其他相关连接

下面是我写的一些博客,其实不一定非常详细但是都是按照我学习过程中总结的方式进行的。
YOLO-V1:https://blog.csdn.net/worewolf/article/details/116563934?spm=1001.2014.3001.5501
YOLO-V2:https://blog.csdn.net/worewolf/article/details/116569183?spm=1001.2014.3001.5501
YOLO-V4:https://blog.csdn.net/worewolf/article/details/116569208?spm=1001.2014.3001.5501
github:https://github.com/hx-0614/yolov4-pytorch-learing



后续

后续将补充转成tensorrt,以及如何修改代码等脚本。希望大家看到我的博客能提出指正。
在学习过程中也参考了很多其他大佬的代码,博客,视频,如有侵权将修正。



推荐阅读
  • 树莓派语音控制的配置方法和步骤
    本文介绍了在树莓派上实现语音控制的配置方法和步骤。首先感谢博主Eoman的帮助,文章参考了他的内容。树莓派的配置需要通过sudo raspi-config进行,然后使用Eoman的控制方法,即安装wiringPi库并编写控制引脚的脚本。具体的安装步骤和脚本编写方法在文章中详细介绍。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • Java太阳系小游戏分析和源码详解
    本文介绍了一个基于Java的太阳系小游戏的分析和源码详解。通过对面向对象的知识的学习和实践,作者实现了太阳系各行星绕太阳转的效果。文章详细介绍了游戏的设计思路和源码结构,包括工具类、常量、图片加载、面板等。通过这个小游戏的制作,读者可以巩固和应用所学的知识,如类的继承、方法的重载与重写、多态和封装等。 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了Redis的基础数据结构string的应用场景,并以面试的形式进行问答讲解,帮助读者更好地理解和应用Redis。同时,描述了一位面试者的心理状态和面试官的行为。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 解决Cydia数据库错误:could not open file /var/lib/dpkg/status 的方法
    本文介绍了解决iOS系统中Cydia数据库错误的方法。通过使用苹果电脑上的Impactor工具和NewTerm软件,以及ifunbox工具和终端命令,可以解决该问题。具体步骤包括下载所需工具、连接手机到电脑、安装NewTerm、下载ifunbox并注册Dropbox账号、下载并解压lib.zip文件、将lib文件夹拖入Books文件夹中,并将lib文件夹拷贝到/var/目录下。以上方法适用于已经越狱且出现Cydia数据库错误的iPhone手机。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 计算机存储系统的层次结构及其优势
    本文介绍了计算机存储系统的层次结构,包括高速缓存、主存储器和辅助存储器三个层次。通过分层存储数据可以提高程序的执行效率。计算机存储系统的层次结构将各种不同存储容量、存取速度和价格的存储器有机组合成整体,形成可寻址存储空间比主存储器空间大得多的存储整体。由于辅助存储器容量大、价格低,使得整体存储系统的平均价格降低。同时,高速缓存的存取速度可以和CPU的工作速度相匹配,进一步提高程序执行效率。 ... [详细]
author-avatar
俊廷淑易89
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有