热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

CodeBERT理解

1.动机大型的预训练模型,比如ELMo、GPT、Bert等提高了NLP任务的最新技术。这些预训练模型在NLP的成功驱动了多模态预训练模型,比如ViBE




1.动机

大型的预训练模型,比如ELMo、GPT、Bert等提高了NLP任务的最新技术。这些预训练模型在NLP的成功驱动了多模态预训练模型,比如ViBERT、VideoBERT(他们从双模式数据,比如语言-图像对中进行自监督学习)

CodeBERT,是一种用于编程语言(PL)和自然语言(NL)的bimodal预训练模型。CodeBERT捕获自然语言和编程语言的语义连接,生成能广泛支持NL-PL理解任务(自然语言代码搜索)和生成任务(代码文档生成)的通用表示形式。

为了利用NL-PL pairs bimodal 实例和大量可用的uni-modal 代码,训练CodeBERT时用一个混合目标函数(包含标准掩码语言模型和 replaced token detection)


replaced token detection

在BERT中,句子内15%的token被选中,其中80%被[MASK]替换,10%被随机替换,10%保持不变,随后将替换后的句子输入到BERT中用于预测那些被替换的token。

论文作者认为BERT只学习这15%的token有点浪费算力,还存在[MASK]不会在实际任务中出现的问题。于是,文章提出了一个新的预训练任务:replaced token detection,即首先使用一个生成器预测句中被mask掉的token,接下来使用预测的token替代句中的[MASK]标记,然后使用一个判别器区分句中的每个token是原始的还是替换后的。
在这里插入图片描述
在预训练后,将判别器用于用于下游任务。作者认为replaced token detection任务让模型(判别器)可以在所有的token上学习,而不是那些仅仅被mask掉的token,这使得计算效率更高。


2.1模型架构

遵循 BERT 和 RoBERTa ,并使用了多层双向 Transformer 作为 CodeBERT 的模型架构。通过使用与Roberta-Base完全相同的模型架构来开发Codebert。模型参数的总数为125M。

RoBERTa:
与BERT相比主要有以下几点改进:


  • 更大的模型参数量(论文提供的训练时间来看,模型使用 1024 块 V100 GPU 训练了 1 天的时间)
  • 更大bacth size。RoBERTa 在训练过程中使用了更大的bacth size。尝试过从 256 到 8000 不等的bacth size。
  • 更多的训练数据(包括:CC-NEWS 等在内的 160GB 纯文本。而最初的BERT使用16GB
    BookCorpus数据集和英语维基百科进行训练)

另外,RoBERTa在训练方法上有以下改进:


  • 去掉下一句预测(NSP)任务
  • 动态掩码。BERT 依赖随机掩码和预测 token。原版的 BERT 实现在数据预处理期间执行一次掩码,得到一个静态掩码。 而 RoBERTa 使用了动态掩码:每次向模型输入一个序列时都会生成新的掩码模式。这样,在大量数据不断输入的过程中,模型会逐渐适应不同的掩码策略,学习不同的语言表征。
  • 文本编码。Byte-Pair Encoding(BPE)是字符级和词级别表征的混合,支持处理自然语言语料库中的众多常见词汇。原版的 BERT 实现使用字符级别的 BPE 词汇,大小为 30K,是在利用启发式分词规则对输入进行预处理之后学得的。Facebook 研究者没有采用这种方式,而是考虑用更大的 byte 级别 BPE 词汇表来训练 BERT,这一词汇表包含 50K 的 subword 单元,且没有对输入作任何额外的预处理或分词。

2.2输入/输出表示

在预训练阶段,将输入设置为两个序列与特殊token的串联




[


C


L


S


]


,



w


1



,



w


1



,



w


2



,


.


.


.


,



w


n



,


[


S


E


P


]


,



c


1



,



c


2



,


.


.


.


,



c


m



,


[


E


O


S


]



[CLS] ,w_{1},w_{1},w_{2},...,w_{n},[SEP],c_{1},c_{2},...,c_{m},[EOS]


[CLS],w1,w1,w2,...,wn,[SEP],c1,c2,...,cm,[EOS]
,其中一个序列是自然语言文本,另一个是编程语言。




[


C


L


S


]



[CLS]


[CLS]
是两个部分前面的特殊token,其最终隐藏代表被视为分类或排名的汇总序列表示形式。

按照transformers中处理文本的标准方式,将自然语言文本看作一系列单词,把它分为WordPiece。把一块代码视为一系列tokens。
输出包括:


  1. 对于自然语言和代码的每个token的上下文向量表示
  2. [CLS]的表示,该表示充当汇总序列表示

2.3训练数据

训练CodeBERT模型并行使用单峰数据和双峰数据,双峰数据就是NL-PL对,单峰数据就是单一的自然语言文本,或者是单一的代码。
数据来自公开可用的开源非fork github存储库,并用一组约束和规则过滤。


  1. 每个项目至少被另一个项目使用过
  2. 每个文档都被截断为第一段
  3. 少于三个token的文档被删除
  4. 少于三行的函数被删除
  5. 函数名带有test的函数被删除

2.4预训练CodeBERT

两个训练目标:


  • 掩码语言模型(MLM),应用于NL-PL pairs的双峰数据
  • 替换token检测(RTD),应用于单峰数据

MLM

给定NL-PL pair(




X


=



w


,


c




X = {w,c}


X=w,c
)的数据点作为输入,其中




w



w


w
是NL单​​词的序列,




c



c


c
是PL tokens的序列,我们首先选择NL和PL的随机位置集要掩盖(分别为





m


w




m^{w}


mw






m


c




m^{c}


mc
),然后用特殊的[MASK] token替换所选位置。其中x中15%的token被mask掉。然后预测被mask掉的原始token。


RTD

在我们的情况下对其进行调整,并在训练中同时使用双峰和单峰数据。具体而言,有两个数据生成器,一个NL,一个PL,均用于为一组随机掩盖的位置生成合理的替代方案。

判别器被训练检测一个单词是否是原始的单词,这是一个二分类问题。值得注意的是,RTD适用于输入的每个位置,如果生成器生成的单词恰好是原来真实的单词,则标签为 real 而不是 fake

在这里插入图片描述
NL和Code生成器都是语言模型,它们基于周围的上下文环境生成了被掩盖位置的合理的token,NL-Code Discriminator是目标的预训练模型,该模型通过检测从NL和PL生成器采样的合理的可替代方案来训练。NL-Code 判别器用于在微调阶段产生通用的表示。NL和Code生成器在微调阶段被抛出。


2.5微调CodeBERT

有不同的设置可以在下游NL-PL任务中使用Codebert。例如,在自然语言代码搜索中,我们将输入与预训练阶段相同,并使用[CLS]的表示来测量代码和自然语言查询之间的语义相关性,而在代码到文本中生成,我们使用编码 - 解码器框架,并使用Codebert初始化生成模型的编码器。


3源码分析

从头开始预训练RoBERTa模型的步骤:


  1. 利用Tokenizer对语料分词
  2. 配置RoBERTa模型
  3. 训练任务的代码
  4. 数据输入模型进行训练

3.1对语料进行分词——RobertaTokenizer

核心代码在\transformers\models\roberta\tokenization_roberta.py


3.2配置RoBERTa模型

核心代码在\transformers\models\roberta\modeling_roberta.py


3.3训练任务的代码

MLM任务:

class RobertaLMHead(nn.Module):
"""Roberta Head for masked language modeling."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x)
return x

  • 对输入的embedding进行线性变换、激活和层归一化
  • 输出形状为[batch_size, seq_length, vocab_size],即预测每个句子每个词是什么类别的概率值

3.4数据输入模型进行训练

class RobertaForMaskedLM(RobertaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
if config.is_decoder:
logger.warning(
"If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)
self.roberta = RobertaModel(config, add_pooling_layer=False)
self.lm_head = RobertaLMHead(config)
# The LM head weights require special treatment only when they are tied with the word embeddings
self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
mask="",
expected_output="' Paris'",
expected_loss=0.1,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
Used to hide legacy arguments that have been deprecated.
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)






推荐阅读
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • http:my.oschina.netleejun2005blog136820刚看到群里又有同学在说HTTP协议下的Get请求参数长度是有大小限制的,最大不能超过XX ... [详细]
  • 本文介绍了Web学习历程记录中关于Tomcat的基本概念和配置。首先解释了Web静态Web资源和动态Web资源的概念,以及C/S架构和B/S架构的区别。然后介绍了常见的Web服务器,包括Weblogic、WebSphere和Tomcat。接着详细讲解了Tomcat的虚拟主机、web应用和虚拟路径映射的概念和配置过程。最后简要介绍了http协议的作用。本文内容详实,适合初学者了解Tomcat的基础知识。 ... [详细]
  • 解决VS写C#项目导入MySQL数据源报错“You have a usable connection already”问题的正确方法
    本文介绍了在VS写C#项目导入MySQL数据源时出现报错“You have a usable connection already”的问题,并给出了正确的解决方法。详细描述了问题的出现情况和报错信息,并提供了解决该问题的步骤和注意事项。 ... [详细]
  • 本文介绍了Android 7的学习笔记总结,包括最新的移动架构视频、大厂安卓面试真题和项目实战源码讲义。同时还分享了开源的完整内容,并提醒读者在使用FileProvider适配时要注意不同模块的AndroidManfiest.xml中配置的xml文件名必须不同,否则会出现问题。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • 突破MIUI14限制,自定义胶囊图标、大图标样式,支持任意APP
    本文介绍了如何突破MIUI14的限制,实现自定义胶囊图标和大图标样式,并支持任意APP。需要一定的动手能力和主题设计师账号权限或者会主题pojie。详细步骤包括应用包名获取、素材制作和封包获取等。 ... [详细]
  • Spring常用注解(绝对经典),全靠这份Java知识点PDF大全
    本文介绍了Spring常用注解和注入bean的注解,包括@Bean、@Autowired、@Inject等,同时提供了一个Java知识点PDF大全的资源链接。其中详细介绍了ColorFactoryBean的使用,以及@Autowired和@Inject的区别和用法。此外,还提到了@Required属性的配置和使用。 ... [详细]
  • 本文记录了在vue cli 3.x中移除console的一些采坑经验,通过使用uglifyjs-webpack-plugin插件,在vue.config.js中进行相关配置,包括设置minimizer、UglifyJsPlugin和compress等参数,最终成功移除了console。同时,还包括了一些可能出现的报错情况和解决方法。 ... [详细]
  • 本文介绍了Swing组件的用法,重点讲解了图标接口的定义和创建方法。图标接口用来将图标与各种组件相关联,可以是简单的绘画或使用磁盘上的GIF格式图像。文章详细介绍了图标接口的属性和绘制方法,并给出了一个菱形图标的实现示例。该示例可以配置图标的尺寸、颜色和填充状态。 ... [详细]
  • 本文讨论了如何使用Web.Config进行自定义配置节的配置转换。作者提到,他将msbuild设置为详细模式,但转换却忽略了带有替换转换的自定义部分的存在。 ... [详细]
author-avatar
mobiledu2502856013
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有