一般,我们对整个训练过程有两种方案
1. 构建数据库的部分是独立,也就是说我们对找到的图像做预处理,将预处理的结果保存起来,这就算构建好训练的数据库了,然后训练时从这数据库里直接拿数据进行训练
2. 训练时实时地预处理一个batch的图像,将处理的结果作为训练的输入
第一种方法将训练集的构建和网络的训练分开,并且将预处理结果存在电脑中,这样做训练的代码会比较简单,且直接读入处理好的数据能让训练速度更快,当然,不足的地方就是不够灵活,如果预处理的方式改变了(例如,原本是RGB转HSV,现在我要RGB转Gray),那么需要重新构建一个数据库,造成硬盘空间的浪费
第二种方法虽然训练速度不如第一种,但是足够灵活,我们主要关注第二种方法。
在TensorFlow中,图像处理主要由tf.image
模块支持,batch generator主要用tf.data.Dataset
实现,下面我们来看看整个流程的具体实现
1 获取所有图片的路径
很明显,如果训练集很大,图片很多,我们无法一次读取所有图片进行训练,因此我们先找到所有图片的路径,在需要读取图片时再根据路径读取图片
import glob
images_dir = '/home/public/butterfly/dataset_detection/JPEGImages/'
images_paths = glob.glob(images_dir+'*.jpg')
images_paths += glob.glob(images_dir+'*.jpeg?s=#39;)
images_paths += glob.glob(images_dir+'*.png')
print('Find {} images, the first 10 image paths are:'.format(len(images_paths)))
for path in images_paths[:10]:print(path)
Find 717 images, the first 10 image paths are:
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001000.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000969.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000805.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000158.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001017.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001155.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001404.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000202.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000568.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000022.jpg
import numpy as np
test_split_factor = 0.2
n_test_path = int(len(images_paths)*test_split_factor)
train_image_paths = np.asarray(images_paths[:-n_test_path])
test_image_paths = np.asarray(images_paths[-n_test_path:])
print('Number of train set is {}'.format(train_image_paths.shape[0]))
print('Number of test set is {}'.format(test_image_paths.shape[0]))
Number of train set is 574
Number of test set is 143
2. Batch Generator
我们将使用tf.data.Dataset
来实现batch generator,这里借鉴了一篇博客 TensorFlow全新的数据读取方式:Dataset API入门教程。我们直接上代码,具体解释请看注释
def gaussian_noise_layer(input_image, std):noise = tf.random_normal(shape=tf.shape(input_image), mean=0.0, stddev=std, dtype=tf.float32)noise_image = tf.cast(input_image, tf.float32) + noisenoise_image = tf.clip_by_value(noise_image, 0, 1.0)return noise_imagedef parse_data(filename):'''导入数据,进行预处理,输出两张图像,分别是输入图像和目标图像(例如,在图像去噪中,输入的是一张带噪声图像,目标图像是无噪声图像)Args:filaneme, 图片的路径Returns:输入图像,目标图像'''image = tf.read_file(filename)image = tf.image.decode_image(image)image = tf.random_crop(image, size=(100,100, 3))image = tf.image.random_flip_left_right(image)image = tf.cast(image, tf.float32) / 255.0n_image =gaussian_noise_layer(image, 0.5)return n_image, image
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
def train_generator(batchsize, shuffle=True):'''生成器,用于生产训练数据Args:batchsize,训练的batch sizeshuffle, 是否随机打乱batchReturns:训练需要的数据'''with tf.Session() as sess:train_dataset = tf.data.Dataset().from_tensor_slices((train_image_paths))train_dataset = train_dataset.map(parse_data)train_dataset = train_dataset.batch(batchsize)train_dataset = train_dataset.repeat()if shuffle:train_dataset = train_dataset.shuffle(buffer_size=4)train_iterator = train_dataset.make_initializable_iterator()sess.run(train_iterator.initializer)train_batch = train_iterator.get_next()while True:try:x_batch, y_batch = sess.run(train_batch)yield (x_batch, y_batch)except:train_iterator = train_dataset.make_initializable_iterator()sess.run(train_iterator.initializer)train_batch = train_iterator.get_next()x_batch, y_batch = sess.run(train_batch)yield (x_batch, y_batch)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
import matplotlib.pyplot as plt
%matplotlib inline
def view_samples(samples, nrows, ncols, figsize=(5,5)):fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharey=True, sharex=True)for ax, img in zip(axes.flatten(), samples):ax.axis('off')ax.set_adjustable('box-forced')im = ax.imshow(img, aspect='equal')plt.subplots_adjust(wspace=0, hspace=0)plt.show()return fig, axes
train_gen = train_generator(16)iteration = 5
for i in range(iteration): noise_x, x = next(train_gen)_ = view_samples(noise_x, 4,4)_ = view_samples(x, 4, 4)
总结
TensorFlow提供了一整套图像预处理以及数据生成的机制,我们实现了一个简单的常用的数据处理框架,总结为三步
1. 获取所有图片的路径
2. 写好预处理的代码(parse_data)
3. 定义好数据生成器
基于以上的流程,稍微加以修改就能够应对大部分训练要求