热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tf.estimatorAPI技术手册(13)——LinearRegressor(线性回归器)

tf.estimatorAPI技术手册(12)——LinearRegressor(线性分类器)(一࿰

tf.estimator API技术手册(12)——LinearRegressor(线性分类器)

  • (一)简 介
  • (二)初始化
  • (三)属 性(Properties)
  • (四)方 法(Methods)
    • (1)evaluate(评估)
    • (2)predict(预测)
    • (3)train(训练)


(一)简 介

继承自Estimator,定义在tensorflow/python/estimator/canned/linear.py中,用来建立线性回归器。示例如下:

categorical_column_a = categorical_column_with_hash_bucket(...)
categorical_column_b = categorical_column_with_hash_bucket(...)categorical_feature_a_x_categorical_feature_b = crossed_column(...)# 创建一个线性分类器
estimator = LinearRegressor(feature_columns=[categorical_column_a,categorical_feature_a_x_categorical_feature_b])# 构建数据输入函数
def input_fn_train: # returns x, y (where y represents label's class index)....
def input_fn_eval: # returns x, y (where y represents label's class index)....
estimator.train(input_fn=input_fn_train)
estimator.evaluate(input_fn=input_fn_eval)
estimator.predict(input_fn=input_fn_predict)

(二)初始化

__init__(feature_columns,model_dir=None,label_dimension=1,weight_column=None,optimizer='Ftrl',config=None,partitioner=None,warm_start_from=None,loss_reduction=losses.Reduction.SUM,sparse_combiner='sum'
)

参数如下:

  • feature_columns:
    特征列

  • model_dir:
    保存模型的目录

  • label_dimension:
    每一个样例中回归目标的个数,这是标签的最后维度和logit张量对象的大小,典型的结构为[批处理大小,标签维度]。

  • weight_column:
    由 tf.feature_column.numeric_column创建的一个字符串或者数字列用来呈现特征列。它将会被乘以example的训练损失。

  • label_vocabulary:
    一个字符串列表用来呈现可能的标签取值,如果给出,则必须为字符型,如果没有给出,则会默认编码为整型,为{0, 1,…, n_classes-1} 。

  • optimizer:
    选择优化器,默认使用Adagrad optimizer,激活函数为tf.nn.relu。

  • config:
    一个运行配置对象,用来配置运行时间。

  • warm_start_from:
    热启动选项。

  • loss_reduction:
    定义损失函数,默认为SUM方法。

  • partitioner:
    可选的。输入层分割器

  • sparse_combiner:
    稀疏组合器,当一个绝对特征列是多价的时候,指定如何去减少损失。


(三)属 性(Properties)
  • config
  • model_dir
  • model_fn
    Returns the model_fn which is bound to self.params.

返回:
model_fn 附有以下标记: def model_fn(features, labels, mode, config)

(四)方 法(Methods)

(1)evaluate(评估)

evaluate(input_fn,steps=None,hooks=None,checkpoint_path=None,name=None
)

评估函数,使用input_fn给出的评估数据评估训练好的模型,参数列表如下:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。
  • checkpoint_path:
    用来保存训练好的模型
  • name:
    如果用户需要在不同的数据集上运行多个评价,如训练集和测试集,则为要进行评估的名称,不同的评估度量被保存在单独的文件夹中,并分别出现在tensorboard中。

(2)predict(预测)

predict(input_fn,predict_keys=None,hooks=None,checkpoint_path=None,yield_single_examples=True
)

使用训练好的模型对新实例进行预测,以下为参数列表:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。

  • predict_keys:
    预测函数最终会返回一系列的结果,但我们可以有选择地让其输出,可供选择的keys列表为[‘logits’, ‘logistic’, ‘probabilities’, ‘class_ids’, ‘classes’],如果不指定的话,默认返回所有值。

  • hooks:
    tf.train.SessionRunHook的子类实例列表,在预测调用中用于传回。

  • checkpoint_path:
    训练好的模型的目录

  • yield_single_examples:
    可以选择False或是True,如果选择False,由model_fn返回整个批次,而不是将批次分解为单个元素。当model_fn返回的一些的张量的第一维度和批处理数量不相等时,这个功能是很用的。


(3)train(训练)

train(input_fn,hooks=None,steps=None,max_steps=None,saving_listeners=None
)

用于训练模型,以下为参数列表:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。

  • hooks:
    tf.train.SessionRunHook的子类实例列表,在预测调用中用于传回。

  • steps:
    模型训练的次数,如果不指定,则会一直训练知道input_fn传回的数据消耗完为止。如果你不想要增量表现,就设置max_steps来替代,注意设置了steps,max_steps必须为None,设置了max_steps,steps必须为None。

  • max_steps:
    模型训练的总次数,注意设置了steps,max_steps必须为None,设置了max_steps,steps必须为None。

  • saving_listeners:
    CheckpointSaverListener对象的列表,用于在检查点保存之前或之后立即运行的回调。


推荐阅读
  • 开源Keras Faster RCNN模型介绍及代码结构解析
    本文介绍了开源Keras Faster RCNN模型的环境需求和代码结构,包括FasterRCNN源码解析、RPN与classifier定义、data_generators.py文件的功能以及损失计算。同时提供了该模型的开源地址和安装所需的库。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • Centos7.6安装Gitlab教程及注意事项
    本文介绍了在Centos7.6系统下安装Gitlab的详细教程,并提供了一些注意事项。教程包括查看系统版本、安装必要的软件包、配置防火墙等步骤。同时,还强调了使用阿里云服务器时的特殊配置需求,以及建议至少4GB的可用RAM来运行GitLab。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • Oracle seg,V$TEMPSEG_USAGE与Oracle排序的关系及使用方法
    本文介绍了Oracle seg,V$TEMPSEG_USAGE与Oracle排序之间的关系,V$TEMPSEG_USAGE是V_$SORT_USAGE的同义词,通过查询dba_objects和dba_synonyms视图可以了解到它们的详细信息。同时,还探讨了V$TEMPSEG_USAGE的使用方法。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 本文介绍了一个适用于PHP应用快速接入TRX和TRC20数字资产的开发包,该开发包支持使用自有Tron区块链节点的应用场景,也支持基于Tron官方公共API服务的轻量级部署场景。提供的功能包括生成地址、验证地址、查询余额、交易转账、查询最新区块和查询交易信息等。详细信息可参考tron-php的Github地址:https://github.com/Fenguoz/tron-php。 ... [详细]
  • GreenDAO快速入门
    前言之前在自己做项目的时候,用到了GreenDAO数据库,其实对于数据库辅助工具库从OrmLite,到litePal再到GreenDAO,总是在不停的切换,但是没有真正去了解他们的 ... [详细]
  • AFNetwork框架(零)使用NSURLSession进行网络请求
    本文介绍了AFNetwork框架中使用NSURLSession进行网络请求的方法,包括NSURLSession的配置、请求的创建和执行等步骤。同时还介绍了NSURLSessionDelegate和NSURLSessionConfiguration的相关内容。通过本文可以了解到AFNetwork框架中使用NSURLSession进行网络请求的基本流程和注意事项。 ... [详细]
  • 本文讨论了在使用Git进行版本控制时,如何提供类似CVS中自动增加版本号的功能。作者介绍了Git中的其他版本表示方式,如git describe命令,并提供了使用这些表示方式来确定文件更新情况的示例。此外,文章还介绍了启用$Id:$功能的方法,并讨论了一些开发者在使用Git时的需求和使用场景。 ... [详细]
  • Struts2+Sring+Hibernate简单配置
    2019独角兽企业重金招聘Python工程师标准Struts2SpringHibernate搭建全解!Struts2SpringHibernate是J2EE的最 ... [详细]
author-avatar
饱和深潜者_463
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有