热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

开发笔记:PaddlePaddle系列CIFAR10图像分类

本文由编程笔记#小编为大家整理,主要介绍了PaddlePaddle系列CIFAR-10图像分类相关的知识,希望对你有一定的参考价值。前言
本文由编程笔记#小编为大家整理,主要介绍了PaddlePaddle系列CIFAR-10图像分类相关的知识,希望对你有一定的参考价值。



前言

本文与前文对手写数字识别分类基本类似的,同样图像作为输入,类别作为输出。这里不同的是,不仅仅是使用简单的卷积神经网络加上全连接层的模型。卷积神经网络大火以来,发展出来许多经典的卷积神经网络模型,包括VGG、ResNet、AlexNet等等。下面将针对CIFAR-10数据集,对图像进行分类。

 


1、CIFAR-10数据集、Reader创建

CIFAR-10数据集分为5个batch的训练集和1个batch的测试集,每个batch包含10,000张图片。每张图像尺寸为32*32的RGB图像,且包含有标签。一共有10个标签:airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck十个类别。

技术分享图片

我在CIFAR-10网站中下载的是[CIFAR-10 python version](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)。数据集完成后,解压得到上述六个文件。上述六个文件都是字典文件,使用cPickle模块即可读入。字典中‘data’需要重新定义维度为1000*32*32*3,维度分别代表[N H W C],即10,000张32*32尺寸的三通道(RGB)图像,再经过转换成为paddlepaddle读取的[N C H W ]维度形式;而字典‘labels’为10000个标签。如此一来,可以建立读取CIFAR-10的reader(与官方例程不同),如下:


技术分享图片技术分享图片

def reader_creator(ROOT,istrain=True,cycle=False):
def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename,
rb) as f:
datadict
= Pickle.load(f)
X
= datadict[data]
Y
= datadict[labels]
""" (N C H W) transpose to (N H W C) """
X
= X.reshape(10000,3,32,32).transpose(0,2,3,1).astype(float)
Y
= np.array(Y)
return X,Y
def reader():
while True:
if istrain:
for b in range(1,6):
f
= os.path.join(ROOT,data_batch_%d%(b))
X,Y
= load_CIFAR_batch(f)
length
= X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
else:
f
= os.path.join(ROOT,test_batch)
X,Y
= load_CIFAR_batch(f)
length
= X.shape[0]
for i in range(length):
yield X[i],Y[i]
if not cycle:
break
return reader


View Code

 


2、VGG网络

VGG网络采用“减小卷积核大小,增加卷积核数量”的思想改造而成,这里直接采用paddlepaddle例程中的VGG网络了,值得提醒的是paddlepaddle中直接有函数img_conv_group提供卷积、池化、dropout一组操作,所以根据VGG的模型,前面卷积层可以划分为5组,然后再经过3层的全连接层得到结果。

技术分享图片

PaddlePaddle例程中根据上图D网络,加入dorpout:


def vgg_bn_drop(input):
def conv_block(ipt, num_filter, groups, dropouts):
return fluid.nets.img_conv_group(
input
=ipt,
#一组的卷积层的卷积核总数,组成list[num_filter num_filter ...]
conv_num_filter=[num_filter] * groups,
conv_filter_size
=3,
conv_act
=relu,
conv_with_batchnorm
=True,
#每组卷积层各层的droput概率
conv_batchnorm_drop_rate=dropouts,
pool_size
=2,
pool_stride
=2,
pool_type
=max)
conv1
= conv_block(input, 64, 2, [0.3, 0]) #[0.3 0]即为第一组两层的dorpout概率,下同
conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3
= conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4
= conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5
= conv_block(conv4, 512, 3, [0.4, 0.4, 0])
drop
= fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1
= fluid.layers.fc(input=drop, size=512, act=None)
bn
= fluid.layers.batch_norm(input=fc1, act=relu)
drop2
= fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2
= fluid.layers.fc(input=drop2, size=512, act=None)
predict
= fluid.layers.fc(input=fc2, size=10, act=softmax)
return predict

 


3、训练

训练程序与上一节例程一样,同样是选取交叉熵作为损失函数,不多累赘讲述。


技术分享图片技术分享图片

def train_network():
predict
= inference_network()
label
= fluid.layers.data(name=label,shape=[1],dtype=int64)
cost
= fluid.layers.cross_entropy(input=predict,label=label)
avg_cost
= fluid.layers.mean(cost)
accuracy
= fluid.layers.accuracy(input=predict,label=label)
return [avg_cost,accuracy]
def optimizer_program():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(data_path,save_path):
BATCH_SIZE
= 128
EPOCH_NUM
= 2
train_reader
= paddle.batch(
paddle.reader.shuffle(reader_creator(data_path),buf_size
=50000),
batch_size
= BATCH_SIZE)
test_reader
= paddle.batch(
reader_creator(data_path,False),
batch_size
=BATCH_SIZE)
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
print("
Pass %d, Epoch %d, Cost %f, Acc %f
" %
(event.step, event.epoch, event.metrics[0],
event.metrics[
1]))
else:
sys.stdout.write(
.)
sys.stdout.flush()
if isinstance(event, fluid.EndEpochEvent):
avg_cost, accuracy
= trainer.test(
reader
=test_reader, feed_order=[image, label])
print(
Test with Pass {0}, Loss {1:2.2}, Acc {2:2.2}
.format(
event.epoch, avg_cost, accuracy))
if save_path is not None:
trainer.save_params(save_path)
place
= fluid.CUDAPlace(0)
trainer
= fluid.Trainer(
train_func
=train_network, optimizer_func=optimizer_program, place=place)
trainer.train(
reader
=train_reader,
num_epochs
=EPOCH_NUM,
event_handler
=event_handler,
feed_order
=[image, label])


View Code

4、测试接口

测试接口也类似,需要特别注意的是图像维度要改为[N C H W]的顺序!


技术分享图片技术分享图片

def infer(params_dir):
place
= fluid.CUDAPlace(0)
inferencer
= fluid.Inferencer(
infer_func
=inference_network, param_path=params_dir, place=place)
# Prepare testing data.
from PIL import Image
import numpy as np
import os
def load_image(file):
im
= Image.open(file)
im
= im.resize((32, 32), Image.ANTIALIAS)
im
= np.array(im).astype(np.float32)
"""transpose [H W C] to [C H W]"""
im
= im.transpose((2, 0, 1))
im
= im / 255.0
# Add one dimension, [N C H W] N=1
im = np.expand_dims(im, axis=0)
return im
cur_dir
= os.path.dirname(os.path.realpath(__file__))
img
= load_image(cur_dir + /dog.png)
# inference
results = inferencer.infer({image: img})
print(results)
lab
= np.argsort(results) # probs and lab are the results of one batch data
print("infer results: ", cifar_classes[lab[0][0][-1]])


View Code

5、运行结果

由于笔者没有GPU服务器,所以只迭代了50次,已经用了8个多小时,但是准确率只有15.6%,测试集方面准确率有17%,效果不理想,用于验证的结果也是错的!


Pass 300, Epoch 49, Cost 2.261115, Acc 0.156250
.........................................................................................
Test with Pass
49, Loss 2.2, Acc 0.17
Classify the cifar10 images...
[array([[
0.05997971, 0.13485196, 0.096842 , 0.09973737, 0.11053724,
0.08180068, 0.13847008, 0.08627985, 0.06851784, 0.12298328]],
dtype
=float32)]
infer results: frog

 


结语

网络比较深,且数据集比较大,训练时间比较长,普通笔记本上面的GT840M聊以胜无吧。

 

本文代码:02_cifar

参考:book/03.image_classification/

 


推荐阅读
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了在Python3中如何使用选择文件对话框的格式打开和保存图片的方法。通过使用tkinter库中的filedialog模块的asksaveasfilename和askopenfilename函数,可以方便地选择要打开或保存的图片文件,并进行相关操作。具体的代码示例和操作步骤也被提供。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
  • EzPP 0.2发布,新增YAML布局渲染功能
    EzPP发布了0.2.1版本,新增了YAML布局渲染功能,可以将YAML文件渲染为图片,并且可以复用YAML作为模版,通过传递不同参数生成不同的图片。这个功能可以用于绘制Logo、封面或其他图片,让用户不需要安装或卸载Photoshop。文章还提供了一个入门例子,介绍了使用ezpp的基本渲染方法,以及如何使用canvas、text类元素、自定义字体等。 ... [详细]
  • 本文介绍了Python对Excel文件的读取方法,包括模块的安装和使用。通过安装xlrd、xlwt、xlutils、pyExcelerator等模块,可以实现对Excel文件的读取和处理。具体的读取方法包括打开excel文件、抓取所有sheet的名称、定位到指定的表单等。本文提供了两种定位表单的方式,并给出了相应的代码示例。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文介绍了Python异常的捕获、传递与抛出操作,并提供了相关的操作示例。通过异常的捕获和传递,可以有效处理程序中的错误情况。同时,还介绍了如何主动抛出异常。通过本文的学习,读者可以掌握Python中异常处理的基本方法和技巧。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • 本文介绍了在iOS开发中使用UITextField实现字符限制的方法,包括利用代理方法和使用BNTextField-Limit库的实现策略。通过这些方法,开发者可以方便地限制UITextField的字符个数和输入规则。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 本文介绍了一个编程问题,要求求解一个给定n阶方阵的鞍点个数。通过输入格式的描述,可以了解到输入的是一个n阶方阵,每个元素都是整数。通过输出格式的描述,可以了解到输出的是鞍点的个数。通过题目集全集传送门,可以了解到提供了两个函数is_line_max和is_rank_min,用于判断一个元素是否为鞍点。本文还提供了三个样例,分别展示了不同情况下的输入和输出。 ... [详细]
author-avatar
捕鱼达人2602914975
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有