热门标签 | HotTags
当前位置:  开发笔记 > 后端 > 正文

BERT的初始标准差0.02以及Warmup、LN的作用

前几天在群里大家讨论到了“Transformer如何解决梯度消失”这个问题,答案有提到残差的,也有提到LN(LayerNorm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而

前几天在群里大家讨论到了“Transformer如何解决梯度消失”这个问题,答案有提到残差的,也有提到LN(Layer Norm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而综合的问题,它其实关联到挺多模型细节,比如“BERT为什么要warmup?”、“BERT的初始化标准差为什么是0.02?”、“BERT做MLM预测之前为什么还要多加一层Dense?”,等等。本文就来集中讨论一下这些问题。


梯度消失说的是什么意思? #

在文章《也来谈谈RNN的梯度消失/爆炸问题》中,我们曾讨论过RNN的梯度消失问题。事实上,一般模型的梯度消失现象也是类似,它指的是(主要是在模型的初始阶段)越靠近输入的层梯度越小,趋于零甚至等于零,而我们主要用的是基于梯度的优化器,所以梯度消失意味着我们没有很好的信号去调整优化前面的层。

换句话说,前面的层也许几乎没有得到更新,一直保持随机初始化的状态;只有比较靠近输出的层才更新得比较好,但这些层的输入是前面没有更新好的层的输出,所以输入质量可能会很糟糕(因为经过了一个近乎随机的变换),因此哪怕后面的层更新好了,总体效果也不好。最终,我们会观察到很反直觉的现象:模型越深,效果越差,哪怕训练集都如此。

解决梯度消失的一个标准方法就是残差链接,正式提出于ResNet中。残差的思想非常简单直接:你不是担心输入的梯度会消失吗?那我直接给它补上一个梯度为常数的项不就行了?最简单地,将模型变成

这样一来,由于多了一条“直通”路x">xx,就算F(x)">F(x)F(x)中的x">xx梯度消失了,x">xx的梯度基本上也能得以保留,从而使得深层模型得到有效的训练。

 


LN真的能缓解梯度消失? #

然而,在BERT和最初的Transformer里边,使用的是Post Norm设计,它把Norm操作加在了残差之后:

其实具体的Norm方法不大重要,不管是Batch Norm还是Layer Norm,结论都类似。在文章《浅谈Transformer的初始化、参数化与标准化》中,我们已经分析过这种Norm结构,这里再来重复一下。

 

在初始化阶段,由于所有参数都是随机初始化的,所以我们可以认为x">xxF(x)">F(x)F(x)是两个相互独立的随机向量,如果假设它们各自的方差是1,那么x+F(x)">x+F(x)x+F(x)的方差就是2,而Norm">NormNorm操作负责将方差重新变为1,那么在初始化阶段,Norm">NormNorm操作就相当于“除以√2">2”:

递归下去就是

我们知道,残差有利于解决梯度消失,但是在Post Norm中,残差这条通道被严重削弱了,越靠近输入,削弱得越严重,残差“名存实亡”。所以说,在Post Norm的BERT模型中,LN不仅不能缓解梯度消失,它还是梯度消失的“元凶”之一。

 


那我们为什么还要加LN? #

那么,问题自然就来了:既然LN还加剧了梯度消失,那直接去掉它不好吗?

是可以去掉,但是前面说了,x+F(x)">x+F(x)x+F(x)的方差就是2了,残差越多方差就越大了,所以还是要加一个Norm操作,我们可以把它加到每个模块的输入,即变为x+F(Norm(x))">x+F(Norm(x))x+F(Norm(x)),最后的总输出再加个Norm">NormNorm就行,这就是Pre Norm结构,这时候每个残差分支是平权的,而不是像Post Norm那样有指数衰减趋势。当然,也有完全不加Norm的,但需要对F(x)">F(x)F(x)进行特殊的初始化,让它初始输出更接近于0,比如ReZero、Skip Init、Fixup等,这些在《浅谈Transformer的初始化、参数化与标准化》也都已经介绍过了。

但是,抛开这些改进不说,Post Norm就没有可取之处吗?难道Transformer和BERT开始就带了一个完全失败的设计?

显然不大可能。虽然Post Norm会带来一定的梯度消失问题,但其实它也有其他方面的好处。最明显的是,它稳定了前向传播的数值,并且保持了每个模块的一致性。比如BERT base,我们可以在最后一层接一个Dense来分类,也可以取第6层接一个Dense来分类;但如果你是Pre Norm的话,取出中间层之后,你需要自己接一个LN然后再接Dense,否则越靠后的层方差越大,不利于优化。

其次,梯度消失也不全是“坏处”,其实对于Finetune阶段来说,它反而是好处。在Finetune的时候,我们通常希望优先调整靠近输出层的参数,不要过度调整靠近输入层的参数,以免严重破坏预训练效果。而梯度消失意味着越靠近输入层,其结果对最终输出的影响越弱,这正好是Finetune时所希望的。所以,预训练好的Post Norm模型,往往比Pre Norm模型有更好的Finetune效果,这我们在《RealFormer:把残差转移到Attention矩阵上面去》也提到过。


我们真的担心梯度消失吗? #

其实,最关键的原因是,在当前的各种自适应优化技术下,我们已经不大担心梯度消失问题了。

这是因为,当前NLP中主流的优化器是Adam及其变种。对于Adam来说,由于包含了动量和二阶矩校正,所以近似来看,它的更新量大致上为

可以看到,分子分母是都是同量纲的,因此分式结果其实就是O(1)">O(1)O(1)的量级,而更新量就是O(η)">O(η)O(η)量级。也就是说,理论上只要梯度的绝对值大于随机误差,那么对应的参数都会有常数量级的更新量;这跟SGD不一样,SGD的更新量是正比于梯度的,只要梯度小,更新量也会很小,如果梯度过小,那么参数几乎会没被更新。

 

所以,Post Norm的残差虽然被严重削弱,但是在base、large级别的模型中,它还不至于削弱到小于随机误差的地步,因此配合Adam等优化器,它还是可以得到有效更新的,也就有可能成功训练了。当然,只是有可能,事实上越深的Post Norm模型确实越难训练,比如要仔细调节学习率和Warmup等。


Warmup是怎样起作用的? #

大家可能已经听说过,Warmup是Transformer训练的关键步骤,没有它可能不收敛,或者收敛到比较糟糕的位置。为什么会这样呢?不是说有了Adam就不怕梯度消失了吗?

要注意的是,Adam解决的是梯度消失带来的参数更新量过小问题,也就是说,不管梯度消失与否,更新量都不会过小。但对于Post Norm结构的模型来说,梯度消失依然存在,只不过它的意义变了。根据泰勒展开式:

也就是说增量f(x+Δx)−f(x)">f(x+Δx)f(x)f(x+Δx)−f(x)是正比于梯度的,换句话说,梯度衡量了输出对输入的依赖程度。如果梯度消失,那么意味着模型的输出对输入的依赖变弱了。

 

Warmup是在训练开始阶段,将学习率从0缓增到指定大小,而不是一开始从指定大小训练。如果不进行Wamrup,那么模型一开始就快速地学习,由于梯度消失,模型对越靠后的层越敏感,也就是越靠后的层学习得越快,然后后面的层是以前面的层的输出为输入的,前面的层根本就没学好,所以后面的层虽然学得快,但却是建立在糟糕的输入基础上的。

很快地,后面的层以糟糕的输入为基础到达了一个糟糕的局部最优点,此时它的学习开始放缓(因为已经到达了它认为的最优点附近),同时反向传播给前面层的梯度信号进一步变弱,这就导致了前面的层的梯度变得不准。但我们说过,Adam的更新量是常数量级的,梯度不准,但更新量依然是数量级,意味着可能就是一个常数量级的随机噪声了,于是学习方向开始不合理,前面的输出开始崩盘,导致后面的层也一并崩盘。

所以,如果Post Norm结构的模型不进行Wamrup,我们能观察到的现象往往是:loss快速收敛到一个常数附近,然后再训练一段时间,loss开始发散,直至NAN。如果进行Wamrup,那么留给模型足够多的时间进行“预热”,在这个过程中,主要是抑制了后面的层的学习速度,并且给了前面的层更多的优化时间,以促进每个层的同步优化。

这里的讨论前提是梯度消失,如果是Pre Norm之类的结果,没有明显的梯度消失现象,那么不加Warmup往往也可以成功训练。


初始标准差为什么是0.02? #

喜欢扣细节的同学会留意到,BERT默认的初始化方法是标准差为0.02的截断正态分布,在《浅谈Transformer的初始化、参数化与标准化》我们也提过,由于是截断正态分布,所以实际标准差会更小,大约是0.02/1.1368472≈0.0176">0.02/1.13684720.01760.02/1.1368472≈0.0176。这个标准差是大还是小呢?对于Xavier初始化来说,一个n×n">n×nn×n的矩阵应该用1/n">1/n1/n的方差初始化,而BERT base的n">nn为768,算出来的标准差是1/768≈0.0361">1/√7680.0361。这就意味着,这个初始化标准差是明显偏小的,大约只有常见初始化标准差的一半。

为什么BERT要用偏小的标准差初始化呢?事实上,这还是跟Post Norm设计有关,偏小的标准差会导致函数的输出整体偏小,从而使得Post Norm设计在初始化阶段更接近于恒等函数,从而更利于优化。具体来说,按照前面的假设,如果x">xx的方差是1,F(x)">F(x)F(x)的方差是σ2">σ2σ2,那么初始化阶段,Norm">NormNorm操作就相当于除以√(1+σ2">1+σ2)。如果σ">σσ比较小,那么残差中的“直路”权重就越接近于1,那么模型初始阶段就越接近一个恒等函数,就越不容易梯度消失。

正所谓“我们不怕梯度消失,但我们也不希望梯度消失”,简单地将初始化标注差设小一点,就可以使得σ">σσ变小一点,从而在保持Post Norm的同时缓解一下梯度消失,何乐而不为?那能不能设置得更小甚至全零?一般来说初始化过小会丧失多样性,缩小了模型的试错空间,也会带来负面效果。综合来看,缩小到标准的1/2,是一个比较靠谱的选择了。

当然,也确实有人喜欢挑战极限的,最近笔者也看到了一篇文章,试图让整个模型用几乎全零的初始化,还训练出了不错的效果,大家有兴趣可以读读,文章为《ZerO Initialization: Initializing Residual Networks with only Zeros and Ones》。


为什么MLM要多加Dense? #

最后,是关于BERT的MLM模型的一个细节,就是BERT在做MLM的概率预测之前,还要多接一个Dense层和LN层,这是为什么呢?不接不行吗?

之前看到过的答案大致上是觉得,越靠近输出层的,越是依赖任务的(Task-Specified),我们多接一个Dense层,希望这个Dense层是MLM-Specified的,然后下游任务微调的时候就不是MLM-Specified的,所以把它去掉。这个解释看上去有点合理,但总感觉有点玄学,毕竟Task-Specified这种东西不大好定量分析。

这里笔者给出另外一个更具体的解释,事实上它还是跟BERT用了0.02的标准差初始化直接相关。刚才我们说了,这个初始化是偏小的,如果我们不额外加Dense就乘上Embedding预测概率分布,那么得到的分布就过于均匀了(Softmax之前,每个logit都接近于0),于是模型就想着要把数值放大。现在模型有两个选择:第一,放大Embedding层的数值,但是Embedding层的更新是稀疏的,一个个放大太麻烦;第二,就是放大输入,我们知道BERT编码器最后一层是LN,LN最后有个初始化为1的gamma参数,直接将那个参数放大就好。

模型优化使用的是梯度下降,我们知道它会选择最快的路径,显然是第二个选择更快,所以模型会优先走第二条路。这就导致了一个现象:最后一个LN层的gamma值会偏大。如果预测MLM概率分布之前不加一个Dense+LN,那么BERT编码器的最后一层的LN的gamma值会偏大,导致最后一层的方差会比其他层的明显大,显然不够优雅;而多加了一个Dense+LN后,偏大的gamma就转移到了新增的LN上去了,而编码器的每一层则保持了一致性。

事实上,读者可以自己去观察一下BERT每个LN层的gamma值,就会发现确实是最后一个LN层的gamma值是会明显偏大的,这就验证了我们的猜测~


希望大家多多海涵批评斧正 #

本文试图回答了Transformer、BERT的模型优化相关的几个问题,有一些是笔者在自己的预训练工作中发现的结果,有一些则是结合自己的经验所做的直观想象。不管怎样,算是分享一个参考答案吧,如果有不当的地方,请大家海涵,也请各位批评斧正~

 

来自

https://kexue.fm/archives/8747

 


欢迎转载,转载请保留页面地址。帮助到你的请点个推荐。



推荐阅读
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 在Kubernetes上部署JupyterHub的步骤和实验依赖
    本文介绍了在Kubernetes上部署JupyterHub的步骤和实验所需的依赖,包括安装Docker和K8s,使用kubeadm进行安装,以及更新下载的镜像等。 ... [详细]
  • 突破MIUI14限制,自定义胶囊图标、大图标样式,支持任意APP
    本文介绍了如何突破MIUI14的限制,实现自定义胶囊图标和大图标样式,并支持任意APP。需要一定的动手能力和主题设计师账号权限或者会主题pojie。详细步骤包括应用包名获取、素材制作和封包获取等。 ... [详细]
  • 合并列值-合并为一列问题需求:createtabletab(Aint,Bint,Cint)inserttabselect1,2,3unionallsel ... [详细]
  • GPT-3发布,动动手指就能自动生成代码的神器来了!
    近日,OpenAI发布了最新的NLP模型GPT-3,该模型在GitHub趋势榜上名列前茅。GPT-3使用的数据集容量达到45TB,参数个数高达1750亿,训练好的模型需要700G的硬盘空间来存储。一位开发者根据GPT-3模型上线了一个名为debuid的网站,用户只需用英语描述需求,前端代码就能自动生成。这个神奇的功能让许多程序员感到惊讶。去年,OpenAI在与世界冠军OG战队的表演赛中展示了他们的强化学习模型,在限定条件下以2:0完胜人类冠军。 ... [详细]
  • 本文讨论了如何使用GStreamer来删除H264格式视频文件中的中间部分,而不需要进行重编码。作者提出了使用gst_element_seek(...)函数来实现这个目标的思路,并提到遇到了一个解决不了的BUG。文章还列举了8个解决方案,希望能够得到更好的思路。 ... [详细]
  • 本文介绍了如何使用n3-charts绘制以日期为x轴的数据,并提供了相应的代码示例。通过设置x轴的类型为日期,可以实现对日期数据的正确显示和处理。同时,还介绍了如何设置y轴的类型和其他相关参数。通过本文的学习,读者可以掌握使用n3-charts绘制日期数据的方法。 ... [详细]
  • 本文讨论了在dva中引入antd组件table时没有显示样式的问题。提供了.roadhogrc文件的配置,包括环境和import的设置。同时介绍了extraBabelPlugins和transform-runtime的使用方法,并解释了libraryName和css的含义。 ... [详细]
  • 本文介绍了使用readlink命令获取文件的完整路径的简单方法,并提供了一个示例命令来打印文件的完整路径。共有28种解决方案可供选择。 ... [详细]
  • 本文整理了Java中org.apache.solr.common.SolrDocument.setField()方法的一些代码示例,展示了SolrDocum ... [详细]
  • 本文整理了常用的CSS属性及用法,包括背景属性、边框属性、尺寸属性、可伸缩框属性、字体属性和文本属性等,方便开发者查阅和使用。 ... [详细]
  • 读手语图像识别论文笔记2
    文章目录一、前言二、笔记1.名词解释2.流程分析上一篇快速门:读手语图像识别论文笔记1(手语识别背景和方法)一、前言一句:“做完了&#x ... [详细]
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • Silverlight 引路蜂二维图形库示例:线段连接类型(LineJoin)
    线段连接类型(LineJoin)指定了线段了连接的方式,有三种不同的连接类型JOIN_MITER,JOIN_ROUND和OIN_BEVEL。下面类型显示了三种不同的 ... [详细]
  • 关于如何快速定义自己的数据集,可以参考我的前一篇文章PyTorch中快速加载自定义数据(入门)_晨曦473的博客-CSDN博客刚开始学习P ... [详细]
author-avatar
手机用户2502862657
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有