热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow图像处理以及数据读取

本文完整代码在https:github.comjiemojiemoTensorflow_Demoblobmasterimage_loader.ipynbTensorflow图像处理

本文完整代码在 https://github.com/jiemojiemo/Tensorflow_Demo/blob/master/image_loader.ipynb



Tensorflow图像处理以及数据读取

本人研究的方向是图像处理,这个领域几乎被深度学习的方法给统治了,例如图像去燥、图像超分辨、图像增强等等。在模拟实现相关论文的方法时,我发现最难的部分并不是深度学习的网络,而是如何构建你的训练集。通常,为了构建一个可训练的数据库我需要做:


  1. 上网找到论文提到的图像数据库,或者自己收集图像
  2. 对图像进行处理,构建训练所需的数据库,处理的方式各种各样,包括


  • 图像颜色域的变化,例如RGB转HSV,RGB转Gray等等
  • 图像大小的缩放,例如将不同大小的图像缩放为大小一致的图像
  • 提前图像块(image patch),就是从整张图像中,扣下小块(例如32*32)的小图像,这么做主要是因为可以增加训练数据的量,另外小块的图像训练起来速度更快,image patch的合理性是基于卷积神经网络的感受野(Receptive field)通常不会太大
  • 数据增强(Data augmentation),例如将图像上下翻转,左右翻转,裁剪,旋转等等。这里有一篇Keras-5 基于 ImageDataGenerator 的 Data Augmentation实现可以让大家大致明白什么是Data augmentation

  • 为了训练,给深度网络喂数据,我还需要写batch generator,就是用来生成一个batch的东西

  • 一般,我们对整个训练过程有两种方案
    1. 构建数据库的部分是独立,也就是说我们对找到的图像做预处理,将预处理的结果保存起来,这就算构建好训练的数据库了,然后训练时从这数据库里直接拿数据进行训练
    2. 训练时实时地预处理一个batch的图像,将处理的结果作为训练的输入

    第一种方法将训练集的构建和网络的训练分开,并且将预处理结果存在电脑中,这样做训练的代码会比较简单,且直接读入处理好的数据能让训练速度更快,当然,不足的地方就是不够灵活,如果预处理的方式改变了(例如,原本是RGB转HSV,现在我要RGB转Gray),那么需要重新构建一个数据库,造成硬盘空间的浪费

    第二种方法虽然训练速度不如第一种,但是足够灵活,我们主要关注第二种方法。

    在TensorFlow中,图像处理主要由tf.image模块支持,batch generator主要用tf.data.Dataset实现,下面我们来看看整个流程的具体实现


    1 获取所有图片的路径

    很明显,如果训练集很大,图片很多,我们无法一次读取所有图片进行训练,因此我们先找到所有图片的路径,在需要读取图片时再根据路径读取图片

    import glob
    # images_dir 下存放着需要预处理的图像
    images_dir = '/home/public/butterfly/dataset_detection/JPEGImages/'# 查找图片文件, 根据具体数据集自由添加各种图片格式(jpg, jpeg, png, bmp等等)
    images_paths = glob.glob(images_dir+'*.jpg')
    images_paths += glob.glob(images_dir+'*.jpeg?s=#39;)
    images_paths += glob.glob(images_dir+'*.png')
    print('Find {} images, the first 10 image paths are:'.format(len(images_paths)))
    for path in images_paths[:10]:print(path)

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    Find 717 images, the first 10 image paths are:
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_001000.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_000969.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_000805.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_000158.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_001017.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_001155.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_001404.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_000202.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_000568.jpg
    /home/public/butterfly/dataset_detection/JPEGImages/IMG_000022.jpg

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    import numpy as np
    # split training set and test data
    test_split_factor = 0.2
    n_test_path = int(len(images_paths)*test_split_factor)
    # 转出numpy数据,方便使用
    train_image_paths = np.asarray(images_paths[:-n_test_path])
    test_image_paths = np.asarray(images_paths[-n_test_path:])
    print('Number of train set is {}'.format(train_image_paths.shape[0]))
    print('Number of test set is {}'.format(test_image_paths.shape[0]))

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    Number of train set is 574
    Number of test set is 143

    • 1
    • 2

    2. Batch Generator

    我们将使用tf.data.Dataset来实现batch generator,这里借鉴了一篇博客 TensorFlow全新的数据读取方式:Dataset API入门教程。我们直接上代码,具体解释请看注释

    def gaussian_noise_layer(input_image, std):noise = tf.random_normal(shape=tf.shape(input_image), mean=0.0, stddev=std, dtype=tf.float32)noise_image = tf.cast(input_image, tf.float32) + noisenoise_image = tf.clip_by_value(noise_image, 0, 1.0)return noise_imagedef parse_data(filename):'''导入数据,进行预处理,输出两张图像,分别是输入图像和目标图像(例如,在图像去噪中,输入的是一张带噪声图像,目标图像是无噪声图像)Args:filaneme, 图片的路径Returns:输入图像,目标图像'''# 读取图像image = tf.read_file(filename)# 解码图片image = tf.image.decode_image(image)# 数据预处理,或者数据增强,这一步根据需要自由发挥# 随机提取patchimage = tf.random_crop(image, size=(100,100, 3))# 数据增强,随机水平翻转图像image = tf.image.random_flip_left_right(image)# 图像归一化image = tf.cast(image, tf.float32) / 255.0# 加噪声n_image =gaussian_noise_layer(image, 0.5)return n_image, image

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    def train_generator(batchsize, shuffle=True):'''生成器,用于生产训练数据Args:batchsize,训练的batch sizeshuffle, 是否随机打乱batchReturns:训练需要的数据'''with tf.Session() as sess:# 创建数据库train_dataset = tf.data.Dataset().from_tensor_slices((train_image_paths))# 预处理数据train_dataset = train_dataset.map(parse_data)# 设置 batch sizetrain_dataset = train_dataset.batch(batchsize)# 无限重复数据train_dataset = train_dataset.repeat()# 洗牌,打乱if shuffle:train_dataset = train_dataset.shuffle(buffer_size=4)# 创建迭代器train_iterator = train_dataset.make_initializable_iterator()sess.run(train_iterator.initializer)train_batch = train_iterator.get_next()# 开始生成数据while True:try:x_batch, y_batch = sess.run(train_batch)yield (x_batch, y_batch)except:# 如果没有 train_dataset = train_dataset.repeat()# 数据遍历完就到end了,就会抛出异常train_iterator = train_dataset.make_initializable_iterator()sess.run(train_iterator.initializer)train_batch = train_iterator.get_next()x_batch, y_batch = sess.run(train_batch)yield (x_batch, y_batch)

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    import matplotlib.pyplot as plt
    %matplotlib inline
    #%config InlineBackend.figure_format='retina'# 显示图像
    def view_samples(samples, nrows, ncols, figsize=(5,5)):fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharey=True, sharex=True)for ax, img in zip(axes.flatten(), samples):ax.axis('off')ax.set_adjustable('box-forced')im = ax.imshow(img, aspect='equal')plt.subplots_adjust(wspace=0, hspace=0)plt.show()return fig, axes

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    # 测试一下我们的代码
    train_gen = train_generator(16)iteration = 5
    for i in range(iteration): noise_x, x = next(train_gen)_ = view_samples(noise_x, 4,4)_ = view_samples(x, 4, 4)

    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    png

    png


    总结

    TensorFlow提供了一整套图像预处理以及数据生成的机制,我们实现了一个简单的常用的数据处理框架,总结为三步
    1. 获取所有图片的路径
    2. 写好预处理的代码(parse_data)
    3. 定义好数据生成器
    基于以上的流程,稍微加以修改就能够应对大部分训练要求


推荐阅读
  • 基于TensorFlow的Keras高级API实现手写体数字识别
    前言这个项目的话我也是偶然在B站看到一个阿婆主(SvePana)在讲解这个,跟着他的视频敲的代码并学习起来的。并写在自己这里做个笔记也为 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • 本文介绍了Java工具类库Hutool,该工具包封装了对文件、流、加密解密、转码、正则、线程、XML等JDK方法的封装,并提供了各种Util工具类。同时,还介绍了Hutool的组件,包括动态代理、布隆过滤、缓存、定时任务等功能。该工具包可以简化Java代码,提高开发效率。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 解决Cydia数据库错误:could not open file /var/lib/dpkg/status 的方法
    本文介绍了解决iOS系统中Cydia数据库错误的方法。通过使用苹果电脑上的Impactor工具和NewTerm软件,以及ifunbox工具和终端命令,可以解决该问题。具体步骤包括下载所需工具、连接手机到电脑、安装NewTerm、下载ifunbox并注册Dropbox账号、下载并解压lib.zip文件、将lib文件夹拖入Books文件夹中,并将lib文件夹拷贝到/var/目录下。以上方法适用于已经越狱且出现Cydia数据库错误的iPhone手机。 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 闭包一直是Java社区中争论不断的话题,很多语言都支持闭包这个语言特性,闭包定义了一个依赖于外部环境的自由变量的函数,这个函数能够访问外部环境的变量。本文以JavaScript的一个闭包为例,介绍了闭包的定义和特性。 ... [详细]
  • Html5-Canvas实现简易的抽奖转盘效果
    本文介绍了如何使用Html5和Canvas标签来实现简易的抽奖转盘效果,同时使用了jQueryRotate.js旋转插件。文章中给出了主要的html和css代码,并展示了实现的基本效果。 ... [详细]
author-avatar
safadfdfdsfsd
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有