热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow保存恢复模型及微调

使用tensorflow的过程中,我们常常会用到训练好的模型。我们可以直接使用训练好的模型进行测试或者对训练好的模型做进一步的微调。(微调是指初始化网络参数的时候不再是随机初始化,

使用tensorflow的过程中,我们常常会用到训练好的模型。我们可以直接使用训练好的模型进行测试或者对训练好的模型做进一步的微调。(微调是指初始化网络参数的时候不再是随机初始化,而是使用先前训练好的权重参数进行初始化,在此基础上对网络的全部或者局部参数进行重新训练的过程)。为了实现模型的复用或微调,我将从以下四个方面进行说明:

  • 模型是指什么?
  • 如何保存模型?
  • 如何恢复模型?
  • 如何进行微调?

一、模型是指什么?

tensorflow训练后需要保存的模型主要包含两部分,一是网络图,二是网络图里的参数值。保存的模型文件结构如下(假设每过1000次保存一次):

checkpoint
MyModel-1000.meta
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-2000.meta
MyModel-2000.data-00000-of-00001
MyModel-2000.index
MyModel-3000.meta
MyModel-3000.data-00000-of-00001
MyModel-3000.index
.......

1 checkpoint

checkpoint是一个文本文件,如下所示。其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。

model_checkpoint_path: "MyModel-3000"
all_model_checkpoint_paths: "MyModel-1000"
all_model_checkpoint_paths: "MyModel-2000"
all_model_checkpoint_paths: "MyModel-3000"
......

2 .meta文件

.meta 文件用于保存网络结构,且以 protocol buffer 格式进行保存。protocol buffer是Google 公司内部使用的一种轻便高效的数据描述语言。类似于XML能够将结构化数据序列化,protocol buffer也可用于序列化结构化数据,并将其用于数据存储、通信协议等方面。相较于XML,protocol buffer更小、更快、也更简单。

3 .data-00000-of-00001 文件和 .index 文件

在tensorflow 0.11之前,保存的文件结构如下。tensorflow 0.11之后,将ckpt文件拆分为了.data-00000-of-00001 和 .index 两个文件。.ckpt是二进制文件,保存了所有变量的值及变量的名称。拆分后的.data-00000-of-00001 保存的是变量值,.index文件保存的是.data文件中数据和 .meta文件中结构图之间的对应关系(也就是变量的名称)

checkpoint
MyModel.meta
MyModel.ckpt

二、如何保存模型?

tensorflow 提供tf.train.Saver类及tf.train.Saver类下面的save方法共同保存模型。下面分别说明tf.train.Saver类及save方法:

tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False,
saver_def=None, builder=None, defer_build=False, allow_empty=False,
write_version=saver_pb2.SaverDef.V2, pad_step_number=False)
就常用的参数进行说明:
var_list:如果我们不对tf.train.Saver指定任何参数,默认会保存所有变量。如果你只想保存一部分变量,
可以通过将需要保存的变量构造list或者dictionary,赋值给var_list。
max_to_keep:tensorflow默认只会保存最近的5个模型文件,如果你希望保存更多,可以通过max_to_keep来指定
keep_checkpoint_every_n_hours:设置每隔几小时保存一次模型
save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix="meta",
write_meta_graph=True, write_state=True)
就常用的参数进行说明:
sess:在tensorflow中,只有开启session时数据才会流动,因此保存模型的时候必须传入session。
save_path: 模型保存的路径及模型名称。
global_step:定义每隔多少步保存一次模型,每次会在保存的模型名称后面加上global_step的值作为后缀
write_meta_graph:布尔值,True表示每次都保存图,False表示不保存图(由于图是不变的,没必要每次都去保存)
注意:保存变量的时候必须在session中;保存的变量必须已经初始化;

1.简单示例

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
w3 = tf.Variable(tf.random_normal(shape=[1]), name='w3')
saver = tf.train.Saver()#未指定任何参数,默认保存所有变量。等价于saver = tf.train.Saver(tf.trainable_variables())
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, save_path)

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta

2.经典示例

import tensorflow as tf
from six.moves import xrange
import os
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w11')#变量w1在内存中的名字是w11;恢复变量时应该与name的名字保持一致
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w22')
w3 = tf.Variable(tf.random_normal(shape=[5]), name='w33')
#保存一部分变量[w1,w2];只保存最近的5个模型文件;每2小时保存一次模型
saver = tf.train.Saver([w1, w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel
# Launch the graph and train, saving the model every 1,000 steps.
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in xrange(100):
if step % 10 == 0:
# 每隔step=10步保存一次模型( keep_checkpoint_every_n_hours与global_step可同时使用,表示'与',通常任选一个就够了);
#每次会在保存的模型名称后面加上global_step的值作为后缀
# write_meta_graph=False表示不保存图
saver.save(sess, save_path, global_step=step, write_meta_graph=False)
# 如果模型文件中没有保存网络图,则使用如下语句保存一张网络图(由于网络图不变,只保存一次就行)
if not os.path.exists('./checkpoint_dir/MyModel.meta'):
# saver.export_meta_graph(filename=None, collection_list=None,as_text=False,export_scope=None,clear_devices=False)
# saver.export_meta_graph()仅仅保存网络图;参数filename表示网络图保存的路径即网络图名称
saver.export_meta_graph('./checkpoint_dir/MyModel.meta')#定义网络图保存的路径./checkpoint_dir/及网络图名称MyModel.meta
#注意:tf.train.export_meta_graph()等价于tf.train.Saver.export_meta_graph()

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.meta
MyModel-50.data-00000-of-00001
MyModel-50.index
MyModel-60.data-00000-of-00001
MyModel-60.index
MyModel-70.data-00000-of-00001
MyModel-70.index
MyModel-80.data-00000-of-00001
MyModel-80.index
MyModel-90.data-00000-of-00001
MyModel-90.index

三、如何恢复模型?

tensorflow保存模型时将网络图和网络图里的参数值分开保存。因此,在恢复模型时,也要分为2步:构造网络图和加载参数。

1 构造网络图

构造网络图可以手动创建(需要创建一个跟保存的模型一模一样的网络图)

也可以从meta文件里加载graph进行创建,如下:

#首先恢复graph
saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel.meta')

2 恢复参数有两种方式,如下:

with tf.Session() as sess:
#恢复最新保存的权重
saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))
#指定一个权重恢复
saver.restore(sess, './checkpoint_dir/MyModel-50')#注意不要加文件后缀名。若权重保存为.ckpt则需要加上后缀

四、如何进行微调?

上面叙述了如何恢复模型,那么,对于恢复出来的模型应该如何使用呢?这里以tensorflow官网给出的vgg为例进行说明。下载地址

恢复出来的模型有四种用途:

  • 查看模型参数
  • 直接使用原始模型进行测试
  • 扩展原始模型(直接使用扩展后的网络进行测试,扩展后需要重新训练的情况见微调部分)
  • 微调:使用训练好的权重参数进行初始化,在此基础上对网络的全部或局部参数进行重新训练

1.查看模型参数

import tensorflow as tf
import vgg
# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './vgg_16.ckpt') # 权重保存为.ckpt则需要加上后缀
""" 查看恢复的模型参数 tf.trainable_variables()查看的是所有可训练的变量; tf.global_variables()获得的与tf.trainable_variables()类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量; sess.graph.get_operations()则可以获得几乎所有的operations相关的tensor """
tvs = [v for v in tf.trainable_variables()]
print('获得所有可训练变量的权重:')
for v in tvs:
print(v.name)
print(sess.run(v))

gv = [v for v in tf.global_variables()]
print('获得所有变量:')
for v in gv:
print(v.name, '\n')

# sess.graph.get_operations()可以换为tf.get_default_graph().get_operations()
ops = [o for o in sess.graph.get_operations()]
print('获得所有operations相关的tensor:')
for o in ops:
print(o.name, '\n')

2.直接使用原始模型进行测试

import tensorflow as tf
import vgg
import numpy as np
import cv2
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = cv2.resize(image, (224,224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))
#build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
print(end_points)
saver = tf.train.Saver()
with tf.Session() as sess:
#恢复权重
saver.restore(sess, './vgg_16.ckpt')#权重保存为.ckpt则需要加上后缀

# Get input and output tensors
# 需要特别注意,get_tensor_by_name后面传入的参数,如果没有重复,需要在后面加上“:0”
# sess.graph等价于tf.get_default_graph()
input = sess.graph.get_tensor_by_name('inputs:0')
output = sess.graph.get_tensor_by_name('vgg_16/fc8/squeezed:0')

# Run forward pass to calculate pred
#使用不同的数据运行相同的网络,只需将新数据通过feed_dict传递到网络即可。
pred = sess.run(output, feed_dict={input:res_image})
#得到使用vgg网络对输入图片的分类结果
print(np.argmax(pred, 1))

3.扩展原始模型

import tensorflow as tf
import vgg
import numpy as np
import cv2
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = cv2.resize(image, (224, 224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))
# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
print(end_points)
saver = tf.train.Saver()
with tf.Session() as sess:
# 恢复权重
saver.restore(sess, './vgg_16.ckpt') # 权重保存为.ckpt则需要加上后缀

# 明确的网络的输入输出,通过get_tensor_by_name()获取变量
input = sess.graph.get_tensor_by_name('inputs:0')
output = sess.graph.get_tensor_by_name('vgg_16/fc8/squeezed:0')

# add more operations to the graph
# 这里只是简单示例,也可以加上新的网络层。
pred = tf.argmax(output, 1)

# 使用不同的数据运行扩展后的网络(这里扩展后的网络不涉及变量,可以直接使用扩展后的网络进行测试)
pred = sess.run(pred, feed_dict={input: res_image})
print(pred)

4.微调

变量ensorflow as tf
import vgg
import numpy as np
import cv2
from skimage import io
import os
# -----------------------------------------准备数据--------------------------------------
#这里以单张图片作为示例,简单说明原理
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res_image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_CUBIC)#vgg_16有全连接层,需要固定输入尺寸
print(res_image.shape)
res_image = np.expand_dims(res_image, axis=0)#网络输入为四维[batch_size, height, width, channels]
print(res_image.shape)
labels = [[1,0]]#标签
# -----------------------------------------恢复图------------------------------------------
#恢复图的方式有很多,这里采用手动构造一个跟保存权重时一样的graph
graph = tf.get_default_graph()
input = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
y_ = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='labels')
# net=[batch, 2]其中2表示二分类,注意官网给出的vgg_16最终的输出没有经过softmax层
net, end_points = vgg.vgg_16(input, num_classes=2) # 保存的权重模型针对的num_classes=1000,这里改为num_classes=2,因此最后一层需要重新训练
print(net, end_points) # net是网络的输出;end_points是所有变量的集合
#add more operations to the graph
y = tf.nn.softmax(net) # 输出0-1之间的概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
output_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='vgg_16/fc8') # 注意这里的scope是定义graph时 name_scope的名字,不要加:0
print(output_vars)
# loss只作用在var_list列表中的变量,也就是说只训练var_list中的变量,其余变量保持不变。若不指定var_list,则默认重新训练所有变量
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy,var_list=output_vars)
# ----------------------------------------恢复权重------------------------------------------
var = tf.global_variables() # 获取所有变量
print(var)
# var_to_restore = [val for val in var if 'conv1' in val.name or 'conv2' in val.name]#保留变量中含有conv1、conv2的变量
var_to_restore = [val for val in var if 'fc8' not in val.name] # 保留变量名中不含有fc8的变量
print(var_to_restore)
saver = tf.train.Saver(var_to_restore) # 恢复var_to_restore列表中的变量(最后一层变量fc8不恢复)
with tf.Session() as sess:
# restore恢复变量值也是变量初始化的一种方式,对于没有restore的变量需要单独初始化
# 注意如果使用全局初始化,则应在全局初始化后再调用saver.restore()。相当于先通过全局初始化赋值,再通过restore重新赋值。
saver.restore(sess, './vgg_16.ckpt') # 权重保存为.ckpt则需要加上后缀
var_to_init = [val for val in var if 'fc8' in val.name] # 保留变量名中含有fc8的变量
# tf.variable_initializers(tf.global_variables())等价于tf.global_variables_initializer()
sess.run(tf.variables_initializer(var_to_init)) # 没有restore的变量需要单独初始化
# sess.run(tf.global_variables_initializer())
# 用w1,w8测试权重恢复成功没有.正确的情况应该是:w1的值不变,w8的值随机
w1 = sess.graph.get_tensor_by_name('vgg_16/conv1/conv1_1/weights:0')
print(sess.run(w1, feed_dict={input: res_image}))
w8 = sess.graph.get_tensor_by_name('vgg_16/fc8/weights:0')
print('w8', sess.run(w8, feed_dict={input: res_image}))

sess.run(train_op, feed_dict={input:res_image, y_:labels})

五、补充

1 .pb格式的文件

上面提到对于恢复的模型可以直接用来进行测试。对于不再需要改动的模型,我们可以将其保存为.pb格式的文件。

为什么要生成pb文件呢?简单来说就是直接通过tf.saver保存的模型文件其参数和图是分开的。这种形式方便对程序进行微小的改动。但是对于训练好,以后不再需要改动的模型这种形式就不是很必要了。

pb文件就是将变量的值固定下来,直接“烧”到图里面。这个时候只需用户提供一个输入,我们就可以通过模型得到一个输出给用户。pb文件一方面可提供给用户做离线的预测;另一方面,对于线上的模型,一般是通过C++或者C语言编写的程序进行调用。所以模型最终都是写成pb格式的文件。

2 .npy格式的文件

tensorflow保存的模型文件只能在tensorflow框架下使用,不利于将模型权重导入到其他框架使用,同时保存的模型文件无法直接查看。因此经常会考虑转换为.npy格式。.npy文件里的权重值是以数组的形式保存着的,方便查看。

参考:

A quick complete tutorial to save and restore Tensorflow models – CV-Tricks.com

月夜 – 分享网络知识 · 享受快乐生活


推荐阅读
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 欢乐的票圈重构之旅——RecyclerView的头尾布局增加
    项目重构的Git地址:https:github.comrazerdpFriendCircletreemain-dev项目同步更新的文集:http:www.jianshu.comno ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 在Docker中,将主机目录挂载到容器中作为volume使用时,常常会遇到文件权限问题。这是因为容器内外的UID不同所导致的。本文介绍了解决这个问题的方法,包括使用gosu和suexec工具以及在Dockerfile中配置volume的权限。通过这些方法,可以避免在使用Docker时出现无写权限的情况。 ... [详细]
  • eclipse学习(第三章:ssh中的Hibernate)——11.Hibernate的缓存(2级缓存,get和load)
    本文介绍了eclipse学习中的第三章内容,主要讲解了ssh中的Hibernate的缓存,包括2级缓存和get方法、load方法的区别。文章还涉及了项目实践和相关知识点的讲解。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • Oracle分析函数first_value()和last_value()的用法及原理
    本文介绍了Oracle分析函数first_value()和last_value()的用法和原理,以及在查询销售记录日期和部门中的应用。通过示例和解释,详细说明了first_value()和last_value()的功能和不同之处。同时,对于last_value()的结果出现不一样的情况进行了解释,并提供了理解last_value()默认统计范围的方法。该文对于使用Oracle分析函数的开发人员和数据库管理员具有参考价值。 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 本文介绍了游标的使用方法,并以一个水果供应商数据库为例进行了说明。首先创建了一个名为fruits的表,包含了水果的id、供应商id、名称和价格等字段。然后使用游标查询了水果的名称和价格,并将结果输出。最后对游标进行了关闭操作。通过本文可以了解到游标在数据库操作中的应用。 ... [详细]
  • Python瓦片图下载、合并、绘图、标记的代码示例
    本文提供了Python瓦片图下载、合并、绘图、标记的代码示例,包括下载代码、多线程下载、图像处理等功能。通过参考geoserver,使用PIL、cv2、numpy、gdal、osr等库实现了瓦片图的下载、合并、绘图和标记功能。代码示例详细介绍了各个功能的实现方法,供读者参考使用。 ... [详细]
  • 本文讨论了如何使用IF函数从基于有限输入列表的有限输出列表中获取输出,并提出了是否有更快/更有效的执行代码的方法。作者希望了解是否有办法缩短代码,并从自我开发的角度来看是否有更好的方法。提供的代码可以按原样工作,但作者想知道是否有更好的方法来执行这样的任务。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • Day2列表、字典、集合操作详解
    本文详细介绍了列表、字典、集合的操作方法,包括定义列表、访问列表元素、字符串操作、字典操作、集合操作、文件操作、字符编码与转码等内容。内容详实,适合初学者参考。 ... [详细]
  • Imtryingtofigureoutawaytogeneratetorrentfilesfromabucket,usingtheAWSSDKforGo.我正 ... [详细]
author-avatar
赖雨蓉744_128
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有