热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

去除0值和nan_【PyTorch】梯度爆炸、loss在反向传播变为nan

0.遇到大坑笔者在最近的项目中用到了自定义loss函数,代码一切都准备就绪后,在训练时遇到了梯度爆炸的问题,每次训练几个iteration

0. 遇到大坑

笔者在最近的项目中用到了自定义loss函数,代码一切都准备就绪后,在训练时遇到了梯度爆炸的问题,每次训练几个iterations后,梯度和loss都会变为nan。一般情况下,梯度变为nan都是出现了

,
等情况,导致结果变为+inf,也就成了nan。

1. 问题分析

笔者需要的loss函数如下:

其中,

从理论上分析,这个loss函数在反向传播过程中很可能会遇到梯度爆炸,这是为什么呢?反向传播的过程是对loss链式求一阶导数的过程,那么,

的导数为:

由于

,这个导数又可以表示为:

这样的话,出现了类似于

的表达式,也就会出现典型的$0/1$问题了。为了避免这个问题,首先进行了如下的
改变:

经过改变,在

时,不再是
问题了,而是转换为了一个线性函数,梯度成为了恒定的12.9,从理论上来看,避免了梯度爆炸的问题。

2. PyTorch初步实现

在实现这一过程时,依旧...遇到了大坑,下面通过示例代码来说明:

"""
loss = mse(X, gamma_inv(X))
"""
def loss_function(x):mask &#61; (x <0.003).float()gamma_x &#61; mask * 12.9 * x &#43; (1-mask) * (x ** 0.5)loss &#61; torch.mean((x - gamma_x) ** 2)return lossif __name__ &#61;&#61; &#39;__main__&#39;:x &#61; Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad&#61;True)loss &#61; loss_function(x)print(&#39;loss:&#39;, loss)loss.backward()print(x.grad)

改进后的

是一个分支结构&#xff0c;在实现时&#xff0c;就采用了类似于Matlab中矩阵计算的mask方式&#xff0c;mask定义为
&#xff0c;满足条件的$x_i$在mask中对应位置的值为1&#xff0c;因此&#xff0c;
的结构只会保留
的结果&#xff0c;同样的道理&#xff0c;
就实现了上述改进后的
公式。

按理来说&#xff0c;此时&#xff0c;在反向传播过程中的梯度应该是正确的&#xff0c;但是&#xff0c;上面代码的输出结果为&#xff1a;

loss: tensor(0.0105, grad_fn&#61;)
tensor([ nan, 0.1416, -0.0243, -0.0167, 0.0000])

emmm....依旧为nan&#xff0c;问题在理论层面得到了解决&#xff0c;但是&#xff0c;在实现层面依旧没能解决.....

3. 源码调试分析

上面源码的问题依旧在

的实现&#xff0c;这个过程&#xff0c;在Python解释器解释的过程或许是这样的&#xff1a;
  1. 计算
    &#xff0c;对mask进行广播式的乘法&#xff0c;结果为&#xff1a;原本为1的位置变为了12.9&#xff0c;原本为0的位置依旧为0&#xff1b;
  2. 将1.的结果继续与x相乘&#xff0c;本质上仍然是与x的每个元素相乘&#xff0c;只是mask中不满足条件的
    位置为0&#xff0c;表现出的结果是仅对满足条件的
    进行了计算&#xff1b;
  3. 按照2.所述的原理&#xff0c;
    公式的后半部分也是同样的计算过程&#xff0c;即&#xff0c;
    中的每个值依旧会进行
    的计算&#xff1b;

按照上述过程进行前向传播&#xff0c;在反向传播时&#xff0c;梯度不是从某一个分支得到的&#xff0c;而是两个分支的题目相加得到的&#xff0c;换句话说&#xff0c;依旧没能解决梯度变为nan的问题。

4. 源码改进及问题解决

经过第三部分的分析&#xff0c;知道了梯度变为nan的根本原因是当

时依旧参与了
的计算&#xff0c;导致在反向传播时计算出的梯度为nan。

要解决这个问题&#xff0c;就要保证在

时不会进行这样的计算。

新的PyTorch代码如下&#xff1a;

def loss_function(x):mask &#61; x <0.003gamma_x &#61; torch.FloatTensor(x.size()).type_as(x)gamma_x[mask] &#61; 12.9 * x[mask]mask &#61; x >&#61; 0.003gamma_x[mask] &#61; x[mask] ** 0.5loss &#61; torch.mean((x - gamma_x) ** 2)return lossif __name__ &#61;&#61; &#39;__main__&#39;:x &#61; Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad&#61;True)loss &#61; loss_function(x)print(&#39;loss:&#39;, loss)loss.backward()print(x.grad)

改变的地方位于&#96;loss_function&#96;&#xff0c;改变了对于

分支的处理方式&#xff0c;控制并保住每次计算仅有满足条件的值可以参与。此时输出为&#xff1a;

loss: tensor(0.0105, grad_fn&#61;)
tensor([ 0.0000, 0.1416, -0.0243, -0.0167, 0.0000])

就此&#xff0c;问题解决&#xff01;

如有疑问&#xff0c;欢迎留言~



推荐阅读
author-avatar
中二丶夜夜
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有