热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

javadataset读取数据,TensorFlow读写数据

前言回顾前面:众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist

前言

回顾前面:

众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么我们是提前下载好,放在对应的目录上,要么就根据他给的url直接从网上下载)。

一般来说,我们使用TensorFlow是从TFRecord文件中读取数据的。

TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据

所以,这篇文章来聊聊怎么读取TFRecord文件的数据。

一、入门对数据集的数据进行读和写

首先,我们来体验一下怎么造一个TFRecord文件,怎么从TFRecord文件中读取数据,遍历(消费)这些数据。

1.1 造一个TFRecord文件

现在,我们还没有TFRecord文件,我们可以自己简单写一个:

def write_sample_to_tfrecord():

gmv_values = np.arange(10)

click_values = np.arange(10)

label_values = np.arange(10)

with tf.python_io.TFRecordWriter("/Users/zhongfucheng/data/fashin/demo.tfrecord", options=None) as writer:

for _ in range(10):

feature_internal = {

"gmv": tf.train.Feature(float_list=tf.train.FloatList(value=[gmv_values[_]])),

"click": tf.train.Feature(int64_list=tf.train.Int64List(value=[click_values[_]])),

"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label_values[_]]))

}

features_extern = tf.train.Features(feature=feature_internal)

# 使用tf.train.Example将features编码数据封装成特定的PB协议格式

# example = tf.train.Example(features=tf.train.Features(feature=features_extern))

example = tf.train.Example(features=features_extern)

# 将example数据系列化为字符串

example_str = example.SerializeToString()

# 将系列化为字符串的example数据写入协议缓冲区

writer.write(example_str)

if __name__ == '__main__':

write_sample_to_tfrecord()

我相信大家代码应该是能够看得懂的,其实就是分了几步:

生成TFRecord Writer

tf.train.Feature生成协议信息

使用tf.train.Example将features编码数据封装成特定的PB协议格式

将example数据系列化为字符串

将系列化为字符串的example数据写入协议缓冲区

参考资料:

ok,现在我们就有了一个TFRecord文件啦。

1.2 读取TFRecord文件

其实就是通过tf.data.TFRecordDataset这个api来读取到TFRecord文件,生成处dataset对象

对dataset进行处理(shape处理,格式处理...等等)

使用迭代器对dataset进行消费(遍历)

demo代码如下:

import tensorflow as tf

def read_tensorflow_tfrecord_files():

# 定义消费缓冲区协议的parser,作为dataset.map()方法中传入的lambda:

def _parse_function(single_sample):

features = {

"gmv": tf.FixedLenFeature([1], tf.float32),

"click": tf.FixedLenFeature([1], tf.int64), # ()或者[]没啥影响

"label": tf.FixedLenFeature([1], tf.int64)

}

parsed_features = tf.parse_single_example(single_sample, features=features)

# 对parsed 之后的值进行cast.

gmv = tf.cast(parsed_features["gmv"], tf.float64)

click = tf.cast(parsed_features["click"], tf.float64)

label = tf.cast(parsed_features["label"], tf.float64)

return gmv, click, label

# 开始定义dataset以及解析tfrecord格式

filenames = tf.placeholder(tf.string, shape=[None])

# 定义dataset 和 一些列trasformation method

dataset = tf.data.TFRecordDataset(filenames)

parsed_dataset = dataset.map(_parse_function) # 消费缓冲区需要定义在dataset 的map 函数中

batchd_dataset = parsed_dataset.batch(3)

# 创建Iterator

sample_iter = batchd_dataset.make_initializable_iterator()

# 获取next_sample

gmv, click, label = sample_iter.get_next()

training_filenames = [

"/Users/zhongfucheng/data/fashin/demo.tfrecord"]

with tf.Session() as session:

# 初始化带参数的Iterator

session.run(sample_iter.initializer, feed_dict={filenames: training_filenames})

# 读取文件

print(session.run(gmv))

if __name__ == '__main__':

read_tensorflow_tfrecord_files()

无意外的话,我们可以输出这样的结果:

[[0.]

[1.]

[2.]]

ok,现在我们已经大概知道怎么写一个TFRecord文件,以及怎么读取TFRecord文件的数据,并且消费这些数据了。

二、epoch和batchSize术语解释

我在学习TensorFlow翻阅资料时,经常看到一些机器学习的术语,由于自己没啥机器学习的基础,所以很多时候看到一些专业名词就开始懵逼了。

2.1epoch

当一个完整的数据集通过了神经网络一次并且返回了一次,这个过程称为一个epoch。

这可能使我们跟dataset.repeat()方法联系起来,这个方法可以使当前数据集重复一遍。比如说,原有的数据集是[1,2,3,4,5],如果我调用dataset.repeat(2)的话,那么我们的数据集就变成了[1,2,3,4,5],[1,2,3,4,5]

所以会有个说法:假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

2.2batchSize

一般来说我们的数据集都是比较大的,无法一次性将整个数据集的数据喂进神经网络中,所以我们会将数据集分成好几个部分。每次喂多少条样本进神经网络,这个叫做batchSize。

在TensorFlow也提供了方法给我们设置:dataset.batch(),在API中是这样介绍batchSize的:

representing the number of consecutive elements of this dataset to combine in a single batch

我们一般在每次训练之前,会将整个数据集的顺序打乱,提高我们模型训练的效果。这里我们用到的api是:dataset.shffle();

三、再来聊聊dataset

我从官网的介绍中截了一个dataset的方法图(部分):

1460000018530101?w=332&h=988

dataset的功能主要有以下三种:

创建dataset实例

通过文件创建(比如TFRecord)

通过内存创建

对数据集的数据进行变换

比如上面的batch(),常见的map(),flat_map(),zip(),repeat()等等

文档中一般都有给出例子,跑一下一般就知道对应的意思了。

创建迭代器,遍历数据集的数据

3.1 聊聊迭代器

迭代器可以分为四种:

单次。对数据集进行一次迭代,不支持参数化

可初始化迭代

使用前需要进行初始化,支持传入参数。面向的是同一个DataSet

可重新初始化:同一个Iterator从不同的DataSet中读取数据

DataSet的对象具有相同的结构,可以使用tf.data.Iterator.from_structure来进行初始化

问题:每次 Iterator 切换时,数据都从头开始打印了

可馈送(也是通过对象相同的结果来创建的迭代器)

可让您在两个数据集之间切换的可馈送迭代器

通过一个string handler来实现。

可馈送的 Iterator 在不同的 Iterator 切换的时候,可以做到不从头开始。

简单总结:

1、 单次 Iterator ,它最简单,但无法重用,无法处理数据集参数化的要求。

2、 可以初始化的 Iterator ,它可以满足 Dataset 重复加载数据,满足了参数化要求。

3、可重新初始化的 Iterator,它可以对接不同的 Dataset,也就是可以从不同的 Dataset 中读取数据。

4、可馈送的 Iterator,它可以通过 feeding 的方式,让程序在运行时候选择正确的 Iterator,它和可重新初始化的 Iterator 不同的地方就是它的数据在不同的 Iterator 切换时,可以做到不重头开始读取数据。

string handler(可馈送的 Iterator)这种方式是最常使用的,我当时也写了一个Demo来使用了一下,代码如下:

def read_tensorflow_tfrecord_files():

# 开始定义dataset以及解析tfrecord格式.

train_filenames = tf.placeholder(tf.string, shape=[None])

vali_filenames = tf.placeholder(tf.string, shape=[None])

# 加载train_dataset batch_inputs这个方法每个人都不一样的,这个方法我就不给了。

train_dataset = batch_inputs([

train_filenames], batch_size=5, type=False,

num_epochs=2, num_preprocess_threads=3)

# 加载validation_dataset batch_inputs这个方法每个人都不一样的,这个方法我就不给了。

validation_dataset = batch_inputs([vali_filenames

], batch_size=5, type=False,

num_epochs=2, num_preprocess_threads=3)

# 创建出string_handler()的迭代器(通过相同数据结构的dataset来构建)

handle = tf.placeholder(tf.string, shape=[])

iterator = tf.data.Iterator.from_string_handle(

handle, train_dataset.output_types, train_dataset.output_shapes)

# 有了迭代器就可以调用next方法了。

itemid = iterator.get_next()

# 指定哪种具体的迭代器,有单次迭代的,有初始化的。

training_iterator = train_dataset.make_initializable_iterator()

validation_iterator = validation_dataset.make_initializable_iterator()

# 定义出placeholder的值

training_filenames = [

"/Users/zhongfucheng/tfrecord_test/data01aa"]

validation_filenames = ["/Users/zhongfucheng/tfrecord_validation/part-r-00766"]

with tf.Session() as sess:

# 初始化迭代器

training_handle = sess.run(training_iterator.string_handle())

validation_handle = sess.run(validation_iterator.string_handle())

for _ in range(2):

sess.run(training_iterator.initializer, feed_dict={train_filenames: training_filenames})

print("this is training iterator ----")

for _ in range(5):

print(sess.run(itemid, feed_dict={handle: training_handle}))

sess.run(validation_iterator.initializer,

feed_dict={vali_filenames: validation_filenames})

print("this is validation iterator ")

for _ in range(5):

print(sess.run(itemid, feed_dict={vali_filenames: validation_filenames, handle: validation_handle}))

if __name__ == '__main__':

read_tensorflow_tfrecord_files()

参考资料:

3.2 dataset参考资料

在翻阅资料时,发现写得不错的一些博客:

最后

乐于输出干货的Java技术公众号:Java3y。公众号内有200多篇原创技术文章、海量视频资源、精美脑图,不妨来关注一下!

下一篇文章打算讲讲如何理解axis~

1460000018530102?w=258&h=258

觉得我的文章写得不错,不妨点一下赞!



推荐阅读
  • 本文介绍了Python对Excel文件的读取方法,包括模块的安装和使用。通过安装xlrd、xlwt、xlutils、pyExcelerator等模块,可以实现对Excel文件的读取和处理。具体的读取方法包括打开excel文件、抓取所有sheet的名称、定位到指定的表单等。本文提供了两种定位表单的方式,并给出了相应的代码示例。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • Go GUIlxn/walk 学习3.菜单栏和工具栏的具体实现
    本文介绍了使用Go语言的GUI库lxn/walk实现菜单栏和工具栏的具体方法,包括消息窗口的产生、文件放置动作响应和提示框的应用。部分代码来自上一篇博客和lxn/walk官方示例。文章提供了学习GUI开发的实际案例和代码示例。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • IOS开发之短信发送与拨打电话的方法详解
    本文详细介绍了在IOS开发中实现短信发送和拨打电话的两种方式,一种是使用系统底层发送,虽然无法自定义短信内容和返回原应用,但是简单方便;另一种是使用第三方框架发送,需要导入MessageUI头文件,并遵守MFMessageComposeViewControllerDelegate协议,可以实现自定义短信内容和返回原应用的功能。 ... [详细]
  • EzPP 0.2发布,新增YAML布局渲染功能
    EzPP发布了0.2.1版本,新增了YAML布局渲染功能,可以将YAML文件渲染为图片,并且可以复用YAML作为模版,通过传递不同参数生成不同的图片。这个功能可以用于绘制Logo、封面或其他图片,让用户不需要安装或卸载Photoshop。文章还提供了一个入门例子,介绍了使用ezpp的基本渲染方法,以及如何使用canvas、text类元素、自定义字体等。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文讨论了为什么在main.js中写import不会全局生效的问题,并提供了解决方案。在每一个vue文件中都需要写import语句才能使其生效,而在main.js中写import语句则不会全局生效。本文还介绍了使用Swal和sweetalert2库的示例。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • Day2列表、字典、集合操作详解
    本文详细介绍了列表、字典、集合的操作方法,包括定义列表、访问列表元素、字符串操作、字典操作、集合操作、文件操作、字符编码与转码等内容。内容详实,适合初学者参考。 ... [详细]
  • 基于dlib的人脸68特征点提取(眨眼张嘴检测)python版本
    文章目录引言开发环境和库流程设计张嘴和闭眼的检测引言(1)利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68个点标定 ... [详细]
  • Python使用Pillow包生成验证码图片的方法
    本文介绍了使用Python中的Pillow包生成验证码图片的方法。通过随机生成数字和符号,并添加干扰象素,生成一幅验证码图片。需要配置好Python环境,并安装Pillow库。代码实现包括导入Pillow包和随机模块,定义随机生成字母、数字和字体颜色的函数。 ... [详细]
author-avatar
手机用户2502905627_315
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有