热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow实战系列5梯度下降算法

 本文将介绍优化训练神经网络模型的一些常用方法,并给出使用TensorFlow 实现深度学习的最佳实践样例代码。为了更好的介绍优化神经网络训练过程,我们将首先介绍优化神经网络的算法——梯度下降算法。然后在后面的部分中,我们将围绕该算法中的一些元素来优化模型训练过程。

梯度下降算法
 梯度下降算法主要用于优化单个参数的取值,而反向传播算法给出了一个高效的方式在所有参数上使用梯度下降算法,从而使神经网络模型在训练数据上的损失函数尽可能小。反向传播算法是训练神经网络的核心算法,它可以根据定义好的损失函数优化神经网络中参数的取值,从而使神经网络模型在训练数据集上的损失函数达到一个较小值。神经网络模型中参数的优化过程直接决定了模型的质量,是使用神经网络时非常重要的一步。假设用Θ 表示神经网络中的参数,J(Θ) 表示在给定的参数取值下,训练数据上损失函数的大小,那么整个优化过程可以抽象为寻找一个参数Θ,使得J(Θ) 最小。因为目前没有一个通用的方法可以对任意损失函数直接求解最佳的参数取值,所以在实践中,梯度下降算法是最常用的神经网络优化方法。梯度下降算法会迭代式更新参数Θ,不断沿着梯度的反方向让参数朝着总损失更小的方向更新。图1 展示了梯度下降算法的原理。

TensorFlow实战系列5--梯度下降算法

     图1 梯度下降算法思想示意图

 图1 中x 轴表示参数Θ 的取值,y 轴表示损失函数J(Θ) 的值。图1 的曲线表示了在参数Θ 取不同值时,对应损失函数J(Θ) 的大小。假设当前的参数和损失值对应图1 中小圆点的位置,那么梯度下降算法会将参数向x轴左侧移动,从而使得小圆点朝着箭头的方向移动。参数的梯度可以通过求偏导的方式计算,对于参数Θ,其梯度为∂/∂Θ J(Θ)。有了梯度,还需要定义一个学习率η(learning rate)来定义每次参数更新的幅度。从直观上理解,可以认为学习率定义的就是每次参数移动的幅度。通过参数的梯度和学习率,参数更新的公式见下页。下面给出了一个具体的例子来说明梯度下降算法是如何工作的。

TensorFlow实战系列5--梯度下降算法

假设要通过梯度下降算法来优化参数x,使得损失函数J(x)=x2 的值尽量小。梯度下降算法的第一步需要随机产生一个参数x 的初始值,然后再通过梯度和学习率来更新参数x 的取值。在这个样例中,参数x 的梯度为=( ∂ J(x))/ ∂ x=2x,那么使用梯度下降算法每次对参数x 的更新公式为x_(n+1)=x_n-η_n。假设参数的初始值为5,学习率为0.3,那么这个优化过程可以总结为表1。

TensorFlow实战系列5--梯度下降算法

 从表1 中可以看出,经过5 次迭代之后,参数x 的值变成了0.0512,这个和参数最优值0 已经比较接近了。虽然这里给出的是一个非常简单的样例,但是神经网络的优化过程也是可以类推的。神经网络的优化过程可以分为两个阶段,第一个阶段先通过前向传播算法计算得到预测值,并将预测值和真实值做对比得出两者之间的差距。然后在第二个阶段通过反向传播算法计算损失函数对每一个参数的梯度,再根据梯度和学习率使用梯度下降算法更新每一个参数。本书将略去反向传播算法具体的实现方法和
数学证明,有兴趣的读者可以参考论文Learning representations by backpropagating errors。

 因为梯度下降算法要在全部训练数据上最小化损失,所以损失函数J(Θ) 是在所有训练数据上的损失和。这样在每一轮迭代中都需要计算在全部训练数据上的损失函数。在海量训练数据下,要计算所有训练数据的损失函数是非常消耗时间的。为了加速训练过程,可以使用随机梯度下降的算法(stochastic gradient descent)。这个算法优化的不是在全部训练数据上的损失函数,而是在每一轮迭代中,随机优化某一条训练数据上的损失函数。这样每一轮参数更新的速度就大大加快了。因为随机梯度下降算法每次优化的只是某一条数据上的损失函数,所以它的问题也非常明显:在某一条数据上损失函数更小并不代表在全部数据上损失函数更小,于是使用随机梯度下降优化得到的神经网络甚至可能无法达到局部最优。为了综合梯度下降算法和随机梯度下降算法的优缺点,在实际应用中一般采用这两个算法的折中——每次计算一小部分训练数据的损失函数。这一小部分数据被称之为一个batch。通过矩阵运算,每次在一个batch上优化神经网络的参数并不会比单个数据慢太多。另一方面,每次使用一个batch 可以大大减小收敛所需要的迭代次数,同时可以使收敛到的结果更加接近梯度下降的效果。










推荐阅读
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 本文介绍了在Python张量流中使用make_merged_spec()方法合并设备规格对象的方法和语法,以及参数和返回值的说明,并提供了一个示例代码。 ... [详细]
  • 背景应用安全领域,各类攻击长久以来都危害着互联网上的应用,在web应用安全风险中,各类注入、跨站等攻击仍然占据着较前的位置。WAF(Web应用防火墙)正是为防御和阻断这类攻击而存在 ... [详细]
  • 建立分类感知器二元模型对样本数据进行分类
    本文介绍了建立分类感知器二元模型对样本数据进行分类的方法。通过建立线性模型,使用最小二乘、Logistic回归等方法进行建模,考虑到可能性的大小等因素。通过极大似然估计求得分类器的参数,使用牛顿-拉菲森迭代方法求解方程组。同时介绍了梯度上升算法和牛顿迭代的收敛速度比较。最后给出了公式法和logistic regression的实现示例。 ... [详细]
  • cs231n Lecture 3 线性分类笔记(一)
    内容列表线性分类器简介线性评分函数阐明线性分类器损失函数多类SVMSoftmax分类器SVM和Softmax的比较基于Web的可交互线性分类器原型小结注:中文翻译 ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • OCR:用字符识别方法将形状翻译成计算机文字的过程Matlab:商业数学软件;CUDA:CUDA™是一种由NVIDIA推 ... [详细]
  • 3年半巨亏242亿!商汤高估了深度学习,下错了棋?
    转自:新智元三年半研发开支近70亿,累计亏损242亿。AI这门生意好像越来越不好做了。近日,商汤科技已向港交所递交IPO申请。招股书显示& ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
  • 2017亚马逊人工智能奖公布:他们的AI有什么不同?
    事实上,在我们周围,“人工智能”让一切都变得更“智能”极具讽刺意味。随着人类与机器智能之间的界限变得模糊,我们的世界正在变成一个机器 ... [详细]
  • 阿里Treebased Deep Match(TDM) 学习笔记及技术发展回顾
    本文介绍了阿里Treebased Deep Match(TDM)的学习笔记,同时回顾了工业界技术发展的几代演进。从基于统计的启发式规则方法到基于内积模型的向量检索方法,再到引入复杂深度学习模型的下一代匹配技术。文章详细解释了基于统计的启发式规则方法和基于内积模型的向量检索方法的原理和应用,并介绍了TDM的背景和优势。最后,文章提到了向量距离和基于向量聚类的索引结构对于加速匹配效率的作用。本文对于理解TDM的学习过程和了解匹配技术的发展具有重要意义。 ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • Opencv提供了几种分类器,例程里通过字符识别来进行说明的1、支持向量机(SVM):给定训练样本,支持向量机建立一个超平面作为决策平面,使得正例和反例之间的隔离边缘被最大化。函数原型:训练原型cv ... [详细]
author-avatar
一生一世0521
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有