热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

探索YOLOv3源码第1篇训练

YOLO是一句美国的俗语,YouOnlyLiveOnce,人生苦短,及时行乐。本文主要分享,如何实现YOLOv3的算法细节&

YOLO是一句美国的俗语,You Only Live Once,人生苦短,及时行乐。


本文主要分享,如何实现YOLO v3的算法细节,Keras框架。这是第1篇,训练。当然还有第2篇,至第n篇,毕竟,这是一个完整版 :)

本文的GitHub源码:

https://github.com/SpikeKing/keras-yolo3-detection




1. 参数

模型的训练参数,共有5个,即:

(1) 已标注边界框的图片数据集,其格式如下:

图片的位置 框的4个坐标和1个类别ID (xmin,ymin,xmax,ymax,id) ...
dataset/image.jpg 788,351,832,426,0 805,208,855,270,0

(2) 标注框类别的汇总,即数据集中所标注物体的全部类别,例如:

aeroplane
bicycle
bird
...

(3) 预训练模型,用于迁移学习中的微调,可选YOLO v3已训练完成的COCO模型权重,即:

pretrained_path = 'model_data/yolo_weights.h5'

(4) 预测特征图的anchor框集合:


  • 3个尺度的特征图,每个特征图3个anchor框,共9个框,从小到大排列;
  • 框13在大尺度52x52特征图中使用,框46是中尺度26x26,框7~9是小尺度13x13;
  • 大尺度特征图用于检测小物体,小尺度检测大物体;
  • 9个anchor来源于边界框的K-Means聚类。

例如,COCO的anchors列表,如下:

10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326

(5) 图片输入尺寸,默认为416x416,选择416的原因是:


  • 图片尺寸满足32的倍数,在DarkNet网络中,执行5次步长为2卷积,降采样,其卷积操作如下:

x = DarknetConv2D_BN_Leaky(num_filters, (3, 3), strides=(2, 2))(x)

  • 在最底层时,特征图尺寸需要满足为奇数,如13,以保证中心点落在唯一框中。如果为偶数时,则中心点落在中心的4个框中,导致歧义。

2. 创建模型

创建YOLOv3的网络模型,输入:


  • input_shape:图片尺寸;
  • anchors:9个anchor box;
  • num_classes:类别数;
  • freeze_body:冻结模式,1是冻结DarkNet53的层,2是冻结全部,只保留最后3层;
  • weights_path:预训练模型的权重。

即:

model = create_model(input_shape, anchors, num_classes,freeze_body=2,weights_path=pretrained_path)

其中,网络的最后3层是:3个1x1的卷积层,用于将3个尺度的特征图,转换为3个尺度的预测值。

即:

out_filters = num_anchors * (num_classes + 5)
// ...
DarknetConv2D(out_filters, (1, 1))

结构如下:

conv2d_59 (Conv2D) (None, 13, 13, 18) 18450 leaky_re_lu_58[0][0]
conv2d_67 (Conv2D) (None, 26, 26, 18) 9234 leaky_re_lu_65[0][0]
conv2d_75 (Conv2D) (None, 52, 52, 18) 4626 leaky_re_lu_72[0][0]



3. 样本数量

样本洗牌,将数据集拆分为10份,训练9份,验证1份,比较简单。

实现:

val_split = 0.1 # 训练和验证的比例with open(annotation_path) as f:
lines = f.readlines()
np.random.seed(47)
np.random.shuffle(lines)
np.random.seed(None)
num_val = int(len(lines) * val_split) # 验证集数量num_train = len(lines) - num_val # 训练集数量



4. 第1阶段训练

第1阶段,冻结部分网络,只训练底层权重:


  • 优化器使用常见的Adam;
  • 损失函数,直接使用模型的输出y_pred,忽略真值y_true;

即:

model.compile(optimizer=Adam(lr=1e-3), loss={# 使用定制的 yolo_loss Lambda层'yolo_loss': lambda y_true, y_pred: y_pred}) # 损失函数

其中,关于损失函数yolo_loss,以及y_true和y_pred:


  • 把y_true当成输入,作为模型的多输入,把loss封装为层,作为输出;
  • 在模型中,最终输出的y_pred就是loss;
  • 在编译时,将loss设置为y_pred即可,无视y_true;
  • 在训练时,随意添加一个符合结构的y_true即可。

Python的Lambda表达式:

f = lambda y_true, y_pred: y_pred
print(f(1, 2)) # 输出2

模型fit数据,使用数据生成包装器,按批次生成训练和验证数据。最终,模型model存储权重。

实现如下:

batch_size = 32 # batch
model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),steps_per_epoch=max(1, num_train // batch_size),validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors, num_classes),validation_steps=max(1, num_val // batch_size),epochs=50,initial_epoch=0,callbacks=[logging, checkpoint])
# 存储最终的去权重,再训练过程中,也通过回调存储
model.save_weights(log_dir + 'trained_weights_stage_1.h5')

同时,在训练过程中,也会不断保存,epoch完成的模型权重,设置参数为:


  • 只存储权重(save_weights_only);
  • 只存储最优结果(save_best_only);
  • 每隔3个epoch存储一次(period)。

即:

checkpoint = ModelCheckpoint(log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='val_loss', save_weights_only=True,save_best_only=True, period=3) # 只存储weights权重



5. 第2阶段训练

第2阶段,使用第1阶段已训练完成的网络权重,继续训练:


  • 将全部的权重都设置为可训练,而在第1阶段中,则是冻结部分权重;
  • 优化器,仍是Adam,只是学习率有所下降,从1e-3减少至1e-4;
  • 损失函数,仍是只使用y_pred,忽略y_true。

实现:

for i in range(len(model.layers)):model.layers[i].trainable = Truemodel.compile(optimizer=Adam(lr=1e-4),loss={'yolo_loss': lambda y_true, y_pred: y_pred})

第2阶段的模型fit数据,与第1阶段类似,从第50个epoch开始,一直训练到第100个epoch,当触发条件时,则提前终止。额外增加了两个回调reduce_lr和early_stopping,用于控制训练提取终止的时机:


  • reduce_lr:当评价指标不在提升时,减少学习率,每次减少10%,当验证损失值,持续3次未减少时,则终止训练。
  • early_stopping:当验证集损失值,连续增加小于0时,持续10个epoch,则终止训练。

实现:

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1) # 当评价指标不在提升时,减少学习率
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1) # 验证集准确率,下降前终止
batch_size = 32
model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),steps_per_epoch=max(1, num_train // batch_size),validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors,num_classes),validation_steps=max(1, num_val // batch_size),epochs=100,initial_epoch=50,callbacks=[logging, checkpoint, reduce_lr, early_stopping])
model.save_weights(log_dir + 'trained_weights_final.h5')

至此,在第2阶段训练完成之后,输出的网络权重,就是最终的模型权重。




补充1. K-Means

K-Means算法是聚类算法,将一组数据划分为多个组,每组都含有一个中心。


在YOLOv3中,获取数据集的全部anchor box,通过K-Means算法,将这些边界框的宽高,聚类为9类,获取9个聚类中心,面积从小到大排列,作为9个anchor box。


模拟K-Means算法:


  • 创建测试点,X是数据,y是标签,如X:(300,2), y:(300,);
  • 将数据聚类为9类;
  • 输入数据X,训练;
  • 预测X的类别,为y_kmeans;
  • 使用scatter绘制散点图,颜色范围viridis;
  • 获取聚类中心cluster_centers_,以黑色点表示;

源码:

import matplotlib.pyplot as plt
import seaborn as sns
sns.set() # for plot styling
from sklearn.cluster import KMeans
from sklearn.datasets.samples_generator import make_blobsdef test_of_k_means():# 创建测试点,X是数据,y是标签,X:(300,2), y:(300,)X, y_true = make_blobs(n_samples=300, centers=9, cluster_std=0.60, random_state=0)kmeans = KMeans(n_clusters=9) # 将数据聚类kmeans.fit(X) # 数据Xy_kmeans = kmeans.predict(X) # 预测# 颜色范围viridis: https://matplotlib.org/examples/color/colormaps_reference.htmlplt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=20, cmap='viridis') # c是颜色,s是大小centers = kmeans.cluster_centers_ # 聚类的中心plt.scatter(centers[:, 0], centers[:, 1], c='black', s=40, alpha=0.5) # 中心点为黑色plt.show() # 展示if __name__ == '__main__':test_of_k_means()

输出:
在这里插入图片描述K-Means


补充2. EarlyStopping

EarlyStopping是Callback的子类,Callback用于指定在每个阶段开始和结束时,执行的操作。在Callback中,有已经实现的简单子类,如acc、val、loss和val_loss等,还有复杂子类,如ModelCheckpoint和TensorBoard等。

Callback的回调接口,如下:

def on_epoch_begin(self, epoch, logs=None):
def on_epoch_end(self, epoch, logs=None):
def on_batch_begin(self, batch, logs=None):
def on_batch_end(self, batch, logs=None):
def on_train_begin(self, logs=None):
def on_train_end(self, logs=None):

EarlyStopping是提前停止训练的Callback子类。具体地,当训练或验证集中的loss不再减小,即减小的程度小于某个阈值,则会停止训练。这样做,可以提高调参效率,避免浪费资源。

在model的fit数据时,以列表设置callbacks回调,支持设置多个Callback,如:

callbacks=[logging, checkpoint, reduce_lr, early_stopping]

EarlyStopping的参数:


  • monitor:监控数据的类型,支持acc、val_acc、loss、val_loss等;
  • min_delta:停止阈值,与mode参数配合,支持增加或下降;
  • mode:min是最少,max是最多,auto是自动,与min_delta配合;
  • patience:达到阈值之后,能够容忍的epoch数,避免停止在抖动中;
  • verbose:日志的繁杂程度,值越大,输出的信息越多。

min_delta和patience需要相互配合,避免模型停止在抖动的过程中。min_delta降低,patience减少;而min_delta增加,则patience增加。

例如:

early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)

reference

@online{Wang2021Nov,
author = {Wang, C. L.},
title = {{探索 YOLO v3 源码 - 第1篇 训练}},
organization = {微信公众平台},
year = {2021},
month = {11},
date = {2021-11-22},
urldate = {2021-11-22},
note = {[Online; accessed 22. Nov. 2021]},
url = {https://mp.weixin.qq.com/s/T9LshbXoervdJDBuP564dQ},
abstract = {{本文主要分享,如何实现YOLO v3的算法细节,Keras框架。这是第1篇,训练。当然还有第2篇,至第n篇,毕竟,这是一个完整版 :)}}
}


推荐阅读
  • 3.223.28周学习总结中的贪心作业收获及困惑
    本文是对3.223.28周学习总结中的贪心作业进行总结,作者在解题过程中参考了他人的代码,但前提是要先理解题目并有解题思路。作者分享了自己在贪心作业中的收获,同时提到了一道让他困惑的题目,即input details部分引发的疑惑。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 本文介绍了一个Python函数same_set,用于判断两个相等长度的数组是否包含相同的元素。函数会忽略元素的顺序和重复次数,如果两个数组包含相同的元素,则返回1,否则返回0。文章还提供了函数的具体实现代码和样例输入输出。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 如何用UE4制作2D游戏文档——计算篇
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了如何用UE4制作2D游戏文档——计算篇相关的知识,希望对你有一定的参考价值。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • PDO MySQL
    PDOMySQL如果文章有成千上万篇,该怎样保存?数据保存有多种方式,比如单机文件、单机数据库(SQLite)、网络数据库(MySQL、MariaDB)等等。根据项目来选择,做We ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • 工作经验谈之-让百度地图API调用数据库内容 及详解
    这段时间,所在项目中要用到的一个模块,就是让数据库中的内容在百度地图上展现出来,如经纬度。主要实现以下几点功能:1.读取数据库中的经纬度值在百度上标注出来。2.点击标注弹出对应信息。3 ... [详细]
  • Matlab 中的一些小技巧(2)
    1.Ctrl+D打开子程序  在MATLAB的Editor中,将输入光标放到一个子程序名称中间,然后按Ctrl+D可以打开该子函数的m文件。当然这个子程序要在路径列表中(或在当前工作路径中)。实际上 ... [详细]
  • 时域|波形_语音处理基于matlab GUI音频数据处理含Matlab源码 1734期
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了语音处理基于matlabGUI音频数据处理含Matlab源码1734期相关的知识,希望对你有一定的参考价值。 ... [详细]
  • Thisissuewasoriginallyopenedbyashashicorp/terraform#5664.Itwasmigratedhe ... [详细]
author-avatar
AYAKASHIZ
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有