热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

deeplabv3架构_Deeplabv3+源码记录

以为自己会用的是pytorch架构,没有想到还是得先学会tensorflow啊!!!本来打算先读FastFCN的源码的&#x

以为自己会用的是pytorch架构,没有想到还是得先学会tensorflow啊!!!

本来打算先读FastFCN的源码的,但是因为刚出来可以学习的笔记太少,所以我还是选择了deeplabv3+源码学习!

大佬的github:https://codeload.github.com/rishizek/tensorflow-deeplab-v3-plus/zip/master

因为VOC2012的Augmented data一直下载不了,我选择用camvid进行训练,emmm,直接开始吧。(注意camvid的label是通过labelme进行图像标注,然后把得到的json转换为图片,最后把RGB图片再转换成单通道的灰度图片,数据集中的mask已经是单通道的灰度图片,故可以直接转为tfrecord数据格式)

1.TFRecord数据文件

Tensorflow拥有直接的数据输入的格式,所以在进行训练的第一步当然是如何把图片数据image和标签数据label转换成为TFrecored数据。下面介绍转换过程中常用的一些函数:

Img_raw=tf.gfile.FastGFile(dir,'rb').read() #dir具体到每张图片的地址,得到的图片类型是bytes,就不用转换类型了
# 变形记开始啦
Writer=tf.python_io.TFRecordWriter(output_dir) #output_dir是转换后文件地址,文件后缀可以直接是.tfrecord
# 将得到的图片数据转换成example protocol buffer
Example=tf.train.Example(features=tf.train.Features(feature={'img_raw':_bytes_feature(Img_raw)
}))
Writer.write(Example.SerializeToString()) #将信息写入这个数据结构
Writer.close()

以下是deeplabv3+中的数据转换,create_pascal_tf_record.py文件解读:

def create_tf_record(output_filename,image_dir,label_dir,examples):"""Creates a TFRecord file from examples.Args:output_filename: Path to where output file is saved.image_dir: Directory where image files are stored.label_dir: Directory where label files are stored.examples: Examples to parse and save to tf record."""# 创建一个类writerwriter = tf.python_io.TFRecordWriter(output_filename)for idx, example in enumerate(examples):if idx % 500 == 0:tf.logging.info('On image %d of %d', idx, len(examples))# 得到图片image和label的具体地址image_path = os.path.join(image_dir, example + '.png')label_path = os.path.join(label_dir, example + '.png')if not os.path.exists(image_path):tf.logging.warning('Could not find %s, ignoring example.', image_path)continueelif not os.path.exists(label_path):tf.logging.warning('Could not find %s, ignoring example.', label_path)continue# 将两个地址都送入dict_to_tf_example中,得到example对象try:tf_example = dict_to_tf_example(image_path, label_path)writer.write(tf_example.SerializeToString()) # 将对象写入文件地址中except ValueError:tf.logging.warning('Invalid example: %s, ignoring.', example)writer.close()# 将image和label都转换为TFrecord文件
def dict_to_tf_example(image_path, label_path):#以rb读二进制的方式打开图片所在地址with tf.gfile.GFile(image_path, 'rb') as fid:encoded_jpg = fid.read() # 读取图片内容encoded_jpg_io = io.BytesIO(encoded_jpg) # 创建一个临时二进制文件image = PIL.Image.open(encoded_jpg_io) # 读取这个二进制文件,但是返回的不是numpy数据,但可以通过numpy.array进行数据转换if image.format != 'png':raise ValueError('Image format not PNG')# 同上,对label数据进行转换。with tf.gfile.GFile(label_path, 'rb') as fid:encoded_label = fid.read()encoded_label_io = io.BytesIO(encoded_label)label = PIL.Image.open(encoded_label_io)if label.format != 'PNG':raise ValueError('Label format not PNG')if image.size != label.size:raise ValueError('The size of image does not match with that of label.')# 获得图片的宽高(480,360)width, height = image.size# 创建Example对象,每一个example都有下面这些feature# bytes_feature和int64_feature是TF数据的两种类型,即字符串型和int64example = tf.train.Example(features=tf.train.Features(feature={'image/height': dataset_util.int64_feature(height),'image/width': dataset_util.int64_feature(width),'image/encoded': dataset_util.bytes_feature(encoded_jpg),'image/format': dataset_util.bytes_feature('png'.encode('utf8')),'label/encoded': dataset_util.bytes_feature(encoded_label),'label/format': dataset_util.bytes_feature('png'.encode('utf8')),}))return example

2.train.py文件解读

def main(unused_argv):# Using the Winograd non-fused algorithms provides a small performance boost.os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'if FLAGS.clean_model_dir:shutil.rmtree(FLAGS.model_dir, ignore_errors=True)#1.创建RunConfig来更改checkpoint的时间run_config=tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)#2.实例化Estimator,model_fn得到的是resnet-101的网络架构,model_dir是预训练好的模型参数,params参数会自动传送给deeplabv3_plus_model_fnmodel=tf.estimator.Estimator(model_fn=deeplab_model.deeplabv3_plus_model_fn, model_dir=FLAGS.model_dir,config=run_config, params={
'output_stride':FLAGS.output_stride,
'batch_size':FLAGS.batch_size,
'base_architecture':FLAGS.base_architecture,
'pre_trained_model':FLAGS.pre_trained_model,
'batch_norm_decay':_BATCH_NORM_DECAY,
'num_classes':_NUM_CLASSES,
'tensorboard_images_max_outputs':FLAGS.tensorboard_images_max_outputs,
'weight_decay':FLAGS.weight_decay,
'learning_rate_policy':FLAGS.learning_rate_policy,
'num_train':_NUM_IMAGES['train'],
'initial_learning_rate':FLAGS.initial_learning_rate,
'max_iter':FLAGS.max_iter,
'end_learning_rate':FLAGS.end_learning_rate,
'power':_POWER,
'momentum':_MOMENTUM,
'freeze_batch_norm':FLAGS.freeze_batch_norm,
'initial_global_step':FLAGS.initial_global_step
})
# 会被打印出来的内容包括学习率,交叉熵,像素点的准确度和mIOU
for_inrange(FLAGS.train_epochs//FLAGS.epochs_per_eval):tensors_to_log={'learning_rate':'learning_rate','cross_entropy':'cross_entropy','train_px_accuracy':'train_px_accuracy','train_mean_iou':'train_mean_iou',}
# 设置每迭代10次就会打印一次
logging_hook=tf.train.LoggingTensorHook(
tensors=tensors_to_log,every_n_iter=10)
train_hooks=[logging_hook]
eval_hooks=None
# 调用TF的调试器
if FLAGS.debug:debug_hook=tf_debug.LocalCLIDebugHook()train_hooks.append(debug_hook)eval_hooks=[debug_hook]# 3.开始训练模型,函数input_fn作为数据输入的来源
tf.logging.info("Starttraining.")
model.train(input_fn=lambda:input_fn(True,FLAGS.data_dir,FLAGS.batch_size,FLAGS.epochs_per_eval),hooks=train_hooks,)# 4.开始进行模型评估,函数input_fn作为数据输入来源
tf.logging.info("Startevaluation.")
eval_results=model.evaluate(input_fn=lambda:input_fn(False,FLAGS.data_dir,1),hooks=eval_hooks,)
print(eval_results)

3.input_fn()数据输入函数:

def input_fn(is_training,data_dir,batch_size,num_epochs=1):
# 1.切片处理,输入数据,在第一个维度内进行切片!!!!
dataset=tf.data.Dataset.from_tensor_slices(get_filenames(is_training,data_dir)) # 得到的是TFrecord的地址
# print(dataset.output_shapes)
#print(dataset.types) # 可以查看dataset里面的数据类型
dataset=dataset.flat_map(tf.data.TFRecordDataset) # 解析tfrecord文件的每一条记录,序列化后#tf.train.Example,解析函数是parse_record中的parse_single_example()# 2.shuffle()操作:随机化处理
ifis_training:dataset=dataset.shuffle(buffer_size=_NUM_IMAGES['train']) # 训练时,将输入数据的顺序进行打乱dataset=dataset.map(parse_record) # 指定parse_record方法对数据进行改变,返回的是image和label
dataset=dataset.map(lambdaimage,label:preprocess_image(image,label,is_training))
dataset=dataset.prefetch(batch_size)
# repeat操作
dataset=dataset.repeat(num_epochs) # 指定重复的次数
dataset=dataset.batch(batch_size)
# 3.创建迭代器
iterator=dataset.make_one_shot_iterator()
images,labels=iterator.get_next()
Return images,labels

4.parse_record()函数

下面介绍得到TFrecord之后对该文件的解析,即dataset.map(parserecord)中的parse_record()函数

def parse_record(raw_record):
keys_to_features={
'image/height':
tf.FixedLenFeature((),tf.int64),
'image/width':
tf.FixedLenFeature((),tf.int64),
'image/encoded':
tf.FixedLenFeature((),tf.string,default_value=''),
'image/format':
tf.FixedLenFeature((),tf.string,default_value='png'),
'label/encoded':
tf.FixedLenFeature((),tf.string,default_value=''),
'label/format':
tf.FixedLenFeature((),tf.string,default_value='png'),
}
# 解析每一条记录
parsed=tf.parse_single_example(raw_record,keys_to_features) # 解析record的每条记录#height=tf.cast(parsed['image/height'],tf.int32) # 进行类型转换!!!
#width=tf.cast(parsed['image/width'],tf.int32)image=tf.image.decode_image(tf.reshape(parsed['image/encoded'],shape=[]) ,_DEPTH)
image=tf.to_float(tf.image.convert_image_dtype(image,dtype=tf.uint8))
image.set_shape([None,None,3])label=tf.image.decode_image(
tf.reshape(parsed['label/encoded'],shape=[]),1)
label=tf.to_int32(tf.image.convert_image_dtype(label,dtype=tf.uint8))
label.set_shape([None,None,1])return image,label

在input_fn函数中,解析之后得到的文件还有进过再处理,即preprocess_image()函数:

def preprocess_image(image,label,is_training):
"""Preprocessa single image of layout [height,width,depth]."""
ifis_training:
#Randomly scale thei mage and label.
image,label=preprocessing.random_rescale_image_and_label(image,label,_MIN_SCALE,_MAX_SCALE)#Randomly crop or pad a [_HEIGHT,_WIDTH] section of the image and label.
image,label=preprocessing.random_crop_or_pad_image_and_label(
image,label,_HEIGHT,_WIDTH,_IGNORE_LABEL)#Randomly flip the image and label horizontally.
image,label=preprocessing.random_flip_left_right_image_and_label(
image,label)image.set_shape([_HEIGHT,_WIDTH,3])
label.set_shape([_HEIGHT,_WIDTH,1])image=preprocessing.mean_image_subtraction(image)return image,label

经过该函数可以知道,解析后的图片经过再次的处理之后图片的shape=[height,width,3],注意label.shape=[height,width,1]其中的1是它的通道数,因为label是单通道的灰度图,所以为1。

至此,数据的输入就讲完了,deeplabv3+的知识就讲完了,涉及到的argparse和tf.eatimator在onenote中进行过笔记处理,闲时再整理吧。



推荐阅读
  • 超级简单加解密工具的方案和功能
    本文介绍了一个超级简单的加解密工具的方案和功能。该工具可以读取文件头,并根据特定长度进行加密,加密后将加密部分写入源文件。同时,该工具也支持解密操作。加密和解密过程是可逆的。本文还提到了一些相关的功能和使用方法,并给出了Python代码示例。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 怎么在PHP项目中实现一个HTTP断点续传功能发布时间:2021-01-1916:26:06来源:亿速云阅读:96作者:Le ... [详细]
  • 本文介绍了在处理不规则数据时如何使用Python自动提取文本中的时间日期,包括使用dateutil.parser模块统一日期字符串格式和使用datefinder模块提取日期。同时,还介绍了一段使用正则表达式的代码,可以支持中文日期和一些特殊的时间识别,例如'2012年12月12日'、'3小时前'、'在2012/12/13哈哈'等。 ... [详细]
  • Python爬虫中使用正则表达式的方法和注意事项
    本文介绍了在Python爬虫中使用正则表达式的方法和注意事项。首先解释了爬虫的四个主要步骤,并强调了正则表达式在数据处理中的重要性。然后详细介绍了正则表达式的概念和用法,包括检索、替换和过滤文本的功能。同时提到了re模块是Python内置的用于处理正则表达式的模块,并给出了使用正则表达式时需要注意的特殊字符转义和原始字符串的用法。通过本文的学习,读者可以掌握在Python爬虫中使用正则表达式的技巧和方法。 ... [详细]
  • IOS开发之短信发送与拨打电话的方法详解
    本文详细介绍了在IOS开发中实现短信发送和拨打电话的两种方式,一种是使用系统底层发送,虽然无法自定义短信内容和返回原应用,但是简单方便;另一种是使用第三方框架发送,需要导入MessageUI头文件,并遵守MFMessageComposeViewControllerDelegate协议,可以实现自定义短信内容和返回原应用的功能。 ... [详细]
  • 本文介绍了lua语言中闭包的特性及其在模式匹配、日期处理、编译和模块化等方面的应用。lua中的闭包是严格遵循词法定界的第一类值,函数可以作为变量自由传递,也可以作为参数传递给其他函数。这些特性使得lua语言具有极大的灵活性,为程序开发带来了便利。 ... [详细]
  • 使用Ubuntu中的Python获取浏览器历史记录原文: ... [详细]
  • C++字符字符串处理及字符集编码方案
    本文介绍了C++中字符字符串处理的问题,并详细解释了字符集编码方案,包括UNICODE、Windows apps采用的UTF-16编码、ASCII、SBCS和DBCS编码方案。同时说明了ANSI C标准和Windows中的字符/字符串数据类型实现。文章还提到了在编译时需要定义UNICODE宏以支持unicode编码,否则将使用windows code page编译。最后,给出了相关的头文件和数据类型定义。 ... [详细]
  • Go GUIlxn/walk 学习3.菜单栏和工具栏的具体实现
    本文介绍了使用Go语言的GUI库lxn/walk实现菜单栏和工具栏的具体方法,包括消息窗口的产生、文件放置动作响应和提示框的应用。部分代码来自上一篇博客和lxn/walk官方示例。文章提供了学习GUI开发的实际案例和代码示例。 ... [详细]
  • 先看官方文档TheJavaTutorialshavebeenwrittenforJDK8.Examplesandpracticesdescribedinthispagedontta ... [详细]
  • 这篇文章主要介绍了Python拼接字符串的七种方式,包括使用%、format()、join()、f-string等方法。每种方法都有其特点和限制,通过本文的介绍可以帮助读者更好地理解和运用字符串拼接的技巧。 ... [详细]
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 图像因存在错误而无法显示 ... [详细]
author-avatar
枫殇忆华年
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有