热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PyTorch入坑十一:损失函数、正则化深刻剖析softmax+CrossEntropyLoss

这里写目录标题概念LossFunctionCostFunctionObjectiveFunction正则化损失函数交叉熵损失函数nn.CrossEntropyLoss()自信息熵(


这里写目录标题

  • 概念
    • Loss Function
    • Cost Function
    • Objective Function
  • 正则化
  • 损失函数
    • 交叉熵损失函数nn.CrossEntropyLoss()
      • 自信息
      • 熵(信息熵)
      • 相对熵(KL散度)
      • 交叉熵
        • 二分类
        • 多分类
        • 学习过程
        • 同MSE(Mean Squared Error)相比的优势
        • 工程实现中的问题与措施
        • softmax的缺点
        • PyTorch中 CEloss应用
      • PyTorch中的其他损失函数
        • nn.BCELoss
        • nn.BCEWithLogitsLoss


概念


Loss Function

计算一个样本的损失:
在这里插入图片描述


Cost Function

整个训练集(或者batch)的损失平均值
在这里插入图片描述


Objective Function

目标函数是一个更广泛的概念,在机器学习中,目标函数包含Cost和Regularization(正则项):
在这里插入图片描述


  • 正则化:惩罚较大的参数,参数的值越小,通常对应于越光滑的函数,也就是更加简单的函数。因此 就不易发生过拟合的问题。
  • 常用的有L1正则化和L2正则化
  • L1更适用于更适用于特征选择;L2更适用于防止模型过拟合
  • 更多细节参考 正则化的描述

正则化

方差的概念参考:方差

正则化策略的目的就是降低方差,减小过拟合的发生。

常用的手段有:L1正则化、L2正则化、Dropout、提前终止(早停)、数据扩增。
正则化这个话题比较大,待开一篇文章专门描述。


损失函数


交叉熵损失函数nn.CrossEntropyLoss()


  • 交叉熵损失函数常常用于分类任务
  • 交叉熵是衡量两个概率分布之间的差异。所以交叉熵值越低表示两个分布越近

自信息

自信息用于衡量单个事件的不确定性,其公式为:
在这里插入图片描述


熵(信息熵)

熵指的是信息熵,是自信息的期望。用来描述一个事件的不确定性,一个事件越不确定熵越大。熵是整个概率分布的不确定性,用来描述整个概率分布
在这里插入图片描述
伯努利分布的信息熵:
在这里插入图片描述
当事件的概率为0.5(如抛硬币)时,其信息熵最大,这也表示事件的不确定性最大,其熵最大值为0.69;如事件“明天太阳从东方生气”(概率极大),其信息熵比较小


相对熵(KL散度)

相对熵也称为KL散度,相对熵用于衡量两个分布之间的差异,也就是两个分布之间的距离,虽然相对熵可以计算两个分布之间的距离,但是相对熵不是一个距离函数,因为距离函数具有对称性,对称性指的是P到Q的距离等于Q到P的距离,但是相对熵不具备距离函数的对称性。


交叉熵

交叉熵、KL散度、信息熵的关系:在这里插入图片描述
公式中的P是真实的概率分布,也就是训练集中样本的分布,Q是模型输出的分布,因为训练集是固定的,所以H ( P ) 是一个常数,所以交叉熵在优化的时候是优化相对熵。
下面看两个交叉熵具体计算的例子:
参考交叉熵损失函数


二分类

在这里插入图片描述


多分类

在这里插入图片描述


学习过程

交叉熵损失函数经常用于分类问题中,特别是在神经网络做分类问题时,也经常使用交叉熵作为损失函数。此外,由于交叉熵计算中需要输入属于某一类的概率,所以交叉熵几乎每次都和sigmoid(或softmax)函数一起出现。
我们用神经网络最后一层输出的情况,来看一眼整个模型预测、获得损失和学习的流程:


  • 神经网络最后一层得到每个类别的得分scores
  • 该得分经过sigmoid(或softmax)函数获得概率输出;
  • 模型预测的类别概率输出与真实类别的one hot形式进行交叉熵损失函数的计算。

同MSE(Mean Squared Error)相比的优势

首先来看sigmoid+MSE的缺点:
在这里插入图片描述


  • 一句话总结:分类问题中,使用sigmoid/softmx得到概率,配合MSE损失函数时,采用梯度下降法进行学习时,会出现模型一开始训练时,学习速率非常慢的情况
  • 具体来说,在sigmoid层的的输入较大或者较小时,激活函数输出接近于1或者0,导致Loss相对于W的梯度接近0,学习困难。

sigmoid+ CELoss的优点
在这里插入图片描述
如公式所示,Loss关于最后一层的w梯度中,si表示sigmoid的输入,yi为label,xi为sigmoid之前的全连接层的输入。
xi -> 全链接层 -> si -> sigmoid层 -> CEloss层。
公式表明,当激活函数层的输出同label差异较大时,L关于w的梯度会较大,从而快速学习。[同生活中“因为明显的犯错可以快速地学习到正确的东西”比较一致]


工程实现中的问题与措施


  • softmax自身导致的数值问题
  • Softmax loss = softmax和交叉熵(cross-entropy loss)loss组合而成。
  • 所以全称是softmax with cross-entropy loss。
  • 在caffe,tensorflow等开源框架的实现中,直接将两者放在一个层中,而不是分开不同层,可以让数值计算更加稳定,因为正指数概率可能会有非常大的值
  • 参考 softmax数值稳定性 softmax数值稳定性2

在这里插入图片描述

在这里插入图片描述


softmax的缺点

前面说到,softmax一般配合CEloss一起使用。但是softmax这个操作具体什么含义呢。


  • softmax与hardmax

在CNN的分类问题中,我们的ground truth是one-hot形式,下面以四分类为例,理想输出应该是(1,0,0,0),或者说(100%,0%,0%,0%),这就是我们想让CNN学到的终极目标。
网络输出的幅值千差万别,输出最大的那一路对应的就是我们需要的分类结果。通常用百分比形式计算分类置信度,最简单的方式就是计算输出占比,这种最直接最最普通的方式,相对于soft的max,在这里我们把它叫做hard的max。
在这里插入图片描述
而现在通用的是soft的max,将每个输出x非线性放大到exp(x)
在这里插入图片描述
这样做有什么区别呢,看下面的例子:
在这里插入图片描述

相同输出特征情况,soft max比hard max更容易达到终极目标one-hot形式,或者说,softmax降低了训练难度,使得多分类问题更容易收敛。同时Softmax鼓励真实目标类别输出比其他类别要大,但并不要求大很多。对于人脸识别的特征映射(feature embedding)来说,Softmax鼓励不同类别的特征分开,但并不鼓励特征分离很多,如上表(5,1,1,1)时loss就已经很小了,此时CNN接近收敛梯度不再下降。


  • Softmax训练的深度特征,会把整个超空间或者超球,按照分类个数进行划分,保证类别是可分的,这一点对多分类任务如MNIST和ImageNet非常合适,因为测试类别必定在训练类别中。封闭集任务有效
  • Softmax并不要求类内紧凑和类间分离,这一点非常不适合人脸识别任务,因为训练集的1W人数,相对测试集整个世界70亿人类来说,非常微不足道,而我们不可能拿到所有人的训练样本,更过分的是,一般我们还要求训练集和测试集不重叠。
  • 所以需要改造Softmax,除了保证可分性外,还要做到特征向量类内尽可能紧凑,类间尽可能分离,常见的有L-softmax等

PyTorch中 CEloss应用

PyTorch中 CrossEntropyLoss 等价于 LogSoftmax + NLLLoss

CrossEntropyLoss 等价于 LogSoftmax + NLLLoss


PyTorch中的其他损失函数


nn.BCELoss

功能:二分类交叉熵;


nn.BCEWithLogitsLoss

BCEWithLogitsLoss就是把Sigmoid-BCELoss合成一步

更多loss参考PyTorch中更多loss说明


推荐阅读
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了机器学习手册中关于日期和时区操作的重要性以及其在实际应用中的作用。文章以一个故事为背景,描述了学童们面对老先生的教导时的反应,以及上官如在这个过程中的表现。同时,文章也提到了顾慎为对上官如的恨意以及他们之间的矛盾源于早年的结局。最后,文章强调了日期和时区操作在机器学习中的重要性,并指出了其在实际应用中的作用和意义。 ... [详细]
  • 本文介绍了绕过WAF的XSS检测机制的方法,包括确定payload结构、测试和混淆。同时提出了一种构建XSS payload的方法,该payload与安全机制使用的正则表达式不匹配。通过清理用户输入、转义输出、使用文档对象模型(DOM)接收器和源、实施适当的跨域资源共享(CORS)策略和其他安全策略,可以有效阻止XSS漏洞。但是,WAF或自定义过滤器仍然被广泛使用来增加安全性。本文的方法可以绕过这种安全机制,构建与正则表达式不匹配的XSS payload。 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • plt python 画直线_机器学习干货,一步一步通过Python实现梯度下降的学习
    GradientDescent-梯度下降梯度下降法(英语:Gradientdescent)是一个一阶最优化算法,通常也称为最速下降法。要使用梯度下降法找 ... [详细]
  • 2017亚马逊人工智能奖公布:他们的AI有什么不同?
    事实上,在我们周围,“人工智能”让一切都变得更“智能”极具讽刺意味。随着人类与机器智能之间的界限变得模糊,我们的世界正在变成一个机器 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
author-avatar
飞天美术_888_265
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有