热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

基于TensorFlow中如何自定义梯度

这篇文章主要为大家展示了“基于TensorFlow中如何自定义梯度”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究

这篇文章主要为大家展示了“基于TensorFlow中如何自定义梯度”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“基于TensorFlow中如何自定义梯度”这篇文章吧。

前言

在深度学习中,有时候我们需要对某些节点的梯度进行一些定制,特别是该节点操作不可导(比如阶梯除法如 基于TensorFlow中如何自定义梯度 ),如果实在需要对这个节点进行操作,而且希望其可以反向传播,那么就需要对其进行自定义反向传播时的梯度。在有些场景,如[2]中介绍到的梯度反转(gradient inverse)中,就必须在某层节点对反向传播的梯度进行反转,也就是需要更改正常的梯度传播过程,如下图的 基于TensorFlow中如何自定义梯度 所示。

基于TensorFlow中如何自定义梯度

在tensorflow中有若干可以实现定制梯度的方法,这里介绍两种。

1. 重写梯度法

重写梯度法指的是通过tensorflow自带的机制,将某个节点的梯度重写(override),这种方法的适用性最广。我们这里举个例子[3].

符号函数的前向传播采用的是阶跃函数y=sign(x) y = \rm{sign}(x)y=sign(x),如下图所示,我们知道阶跃函数不是连续可导的,因此我们在反向传播时,将其替代为一个可以连续求导的函数y=Htanh(x) y = \rm{Htanh(x)}y=Htanh(x),于是梯度就是大于1和小于-1时为0,在-1和1之间时是1。

基于TensorFlow中如何自定义梯度

使用重写梯度的方法如下,主要是涉及到tf.RegisterGradient()和tf.get_default_graph().gradient_override_map(),前者注册新的梯度,后者重写图中具有名字name='Sign'的操作节点的梯度,用在新注册的QuantizeGrad替代。

#使用修饰器,建立梯度反向传播函数。其中op.input包含输入值、输出值,grad包含上层传来的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op, grad):
 input = op.inputs[0] # 取出当前的输入
 cond = (input>=-1)&(input<=1) # 大于1或者小于-1的值的位置
 zeros = tf.zeros_like(grad) # 定义出0矩阵用于掩膜
 return tf.where(cond, grad, zeros) 
 # 将大于1或者小于-1的上一层的梯度置为0
 
#使用with上下文管理器覆盖原始的sign梯度函数
def binary(input):
 x = input
 with tf.get_default_graph().gradient_override_map({"Sign":&#39;QuantizeGrad&#39;}):
 #重写梯度
  x = tf.sign(x)
 return x
 
#使用
x = binary(x)

其中的def sign_grad(op, grad):是注册新的梯度的套路,其中的op是当前操作的输入值/张量等,而grad指的是从反向而言的上一层的梯度。

通常来说,在tensorflow中自定义梯度,函数tf.identity()是很重要的,其API手册如下:

tf.identity(
 input,
 name=None
)

其会返回一个形状和内容都和输入完全一样的输出,但是你可以自定义其反向传播时的梯度,因此在梯度反转等操作中特别有用。

这里再举个反向梯度[2]的例子,也就是梯度为 基于TensorFlow中如何自定义梯度 而不是 基于TensorFlow中如何自定义梯度

import tensorflow as tf
x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)
@tf.RegisterGradient(&#39;CustomGrad&#39;)
def CustomGrad(op, grad):
#  tf.Print(grad)
 return -grad
 
g = tf.get_default_graph()
oo = x1+x2
with g.gradient_override_map({"Identity": "CustomGrad"}):
 output = tf.identity(oo)
grad_1 = tf.gradients(output, oo)
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(grad_1))

因为-grad,所以这里的梯度输出是[-1]而不是[1]。有一个我们需要注意的是,在自定义函数def CustomGrad()中,返回的值得是一个张量,而不能返回一个参数,比如return 0,这样会报错,如:

AttributeError: &#39;int&#39; object has no attribute &#39;name&#39;

显然,这是因为tensorflow的内部操作需要取返回值的名字而int类型没有名字。

PS:def CustomGrad()这个函数签名是随便你取的。

2. stop_gradient法

对于自定义梯度,还有一种比较简洁的操作,就是利用tf.stop_gradient()函数,我们看下例子[1]:

t = g(x)
y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是f(x),但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

我们看下完整的例子:

import tensorflow as tf

x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)

f = x1+x2*x3
t = -f

y1 = t + tf.stop_gradient(f-t)
y2 = f

grad_1 = tf.gradients(y1, x1)
grad_2 = tf.gradients(y2, x1)
with tf.Session(cOnfig=config) as sess:
 sess.run(tf.global_variables_initializer())

 print(sess.run(grad_1))
 print(sess.run(grad_2))

第一个输出为[-1],第二个输出为[1],显然也实现了梯度的反转。

以上是“基于TensorFlow中如何自定义梯度”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注编程笔记行业资讯频道!


推荐阅读
  • Java容器中的compareto方法排序原理解析
    本文从源码解析Java容器中的compareto方法的排序原理,讲解了在使用数组存储数据时的限制以及存储效率的问题。同时提到了Redis的五大数据结构和list、set等知识点,回忆了作者大学时代的Java学习经历。文章以作者做的思维导图作为目录,展示了整个讲解过程。 ... [详细]
  • 如何自行分析定位SAP BSP错误
    The“BSPtag”Imentionedintheblogtitlemeansforexamplethetagchtmlb:configCelleratorbelowwhichi ... [详细]
  • YOLOv7基于自己的数据集从零构建模型完整训练、推理计算超详细教程
    本文介绍了关于人工智能、神经网络和深度学习的知识点,并提供了YOLOv7基于自己的数据集从零构建模型完整训练、推理计算的详细教程。文章还提到了郑州最低生活保障的话题。对于从事目标检测任务的人来说,YOLO是一个熟悉的模型。文章还提到了yolov4和yolov6的相关内容,以及选择模型的优化思路。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 本文介绍了Redis的基础数据结构string的应用场景,并以面试的形式进行问答讲解,帮助读者更好地理解和应用Redis。同时,描述了一位面试者的心理状态和面试官的行为。 ... [详细]
  • 本文介绍了OC学习笔记中的@property和@synthesize,包括属性的定义和合成的使用方法。通过示例代码详细讲解了@property和@synthesize的作用和用法。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 本文详细介绍了Spring的JdbcTemplate的使用方法,包括执行存储过程、存储函数的call()方法,执行任何SQL语句的execute()方法,单个更新和批量更新的update()和batchUpdate()方法,以及单查和列表查询的query()和queryForXXX()方法。提供了经过测试的API供使用。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • 使用Ubuntu中的Python获取浏览器历史记录原文: ... [详细]
  • 本文介绍了Web学习历程记录中关于Tomcat的基本概念和配置。首先解释了Web静态Web资源和动态Web资源的概念,以及C/S架构和B/S架构的区别。然后介绍了常见的Web服务器,包括Weblogic、WebSphere和Tomcat。接着详细讲解了Tomcat的虚拟主机、web应用和虚拟路径映射的概念和配置过程。最后简要介绍了http协议的作用。本文内容详实,适合初学者了解Tomcat的基础知识。 ... [详细]
  • importjava.util.ArrayList;publicclassPageIndex{privateintpageSize;每页要显示的行privateintpageNum ... [详细]
author-avatar
爱你不愿放cwy
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有