热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

PatternExploitingTrainingMLM任务用于文本匹配【代码解读】

一、总结•原文:#PET-文本分类的又一种妙解:https:xv44586.github.io20201025pet#ccf问答匹配比赛(下

一、总结
• 原文:

# PET-文本分类的又一种妙解:https://xv44586.github.io/2020/10/25/pet/
# ccf问答匹配比赛(下):如何只用“bert”夺冠:https://xv44586.github.io/2021/01/20/ccf-qa-2/

在这里插入图片描述


三、代码注释

原始链接:https://github.com/xv44586/ccf_2020_qa_match

# -*- coding: utf-8 -*-
# @Date : 2020/11/4
# @Author : mingming.xu
# @Email : xv44586@gmail.com
# @File : ccf_2020_qa_match_pet.py
"""
Pattern-Exploiting Training(PET): 增加pattern,将任务转换为MLM任务。
线上f1: 0.761tips:切换模型时,修改对应config_path/checkpoint_path/dict_path路径以及build_transformer_model 内的参数
"""
import os
import numpy as np
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
from toolkit4nlp.backend import keras, K
from toolkit4nlp.tokenizers import Tokenizer, load_vocab
from toolkit4nlp.models import build_transformer_model, Model
from toolkit4nlp.optimizers import *
from toolkit4nlp.utils import pad_sequences, DataGenerator
from toolkit4nlp.layers import *os.environ["CUDA_VISIBLE_DEVICES"] = "1"# PET-文本分类的又一种妙解:https://xv44586.github.io/2020/10/25/pet/
# ccf问答匹配比赛(下):如何只用“bert”夺冠:https://xv44586.github.io/2021/01/20/ccf-qa-2/
num_classes = 32
maxlen = 128
batch_size = 8# BERT baseconfig_path = 'data/pretrained/nezha/NEZHA-Base/bert_config.json'
checkpoint_path = 'data/pretrained/nezha/NEZHA-Base/model.ckpt-900000'
dict_path = 'data/pretrained/nezha/NEZHA-Base/vocab.txt'tokenizer = Tokenizer(dict_path, do_lower_case=True)# pattern
pattern = '下面两个句子的语义相似度较高:'
# tokenizer.encode的第一个位置是cls,所以mask的index要+1
tokens = ["CLS"]+list(pattern)
print(tokens[14])
mask_idx = [14]id2label = {0: '低',1: '高'
}label2id = {v: k for k, v in id2label.items()}
print('label2id:',label2id)#label2id: {'低': 0, '高': 1}
labels = list(id2label.values())
print('labels:',labels)#labels: ['低', '高']
# labels在token中的ids,encode的时候,第一个数是cls,所以取encode输出的tokens[1:-1],代表跳过了cls的
label_ids = np.array([tokenizer.encode(l)[0][1:-1] for l in labels])
print('label_ids:',label_ids)#label_ids: [[ 856] [7770]]# 这里本文其实没有用到
def random_masking(token_ids):"""对输入进行随机mask"""# n个随机数rands &#61; np.random.random(len(token_ids))source, target &#61; [], []for r, t in zip(rands, token_ids):# [mask, 0.15 * 0.8, t(本身), 0.15 * 0.9, 随机, 0.15, 本身&#xff0c;target&#61;0&#xff0c;其余target都为1]if r < 0.15 * 0.8:# 通过mask来预测targetsource.append(tokenizer._token_mask_id)target.append(t)elif r < 0.15 * 0.9:# 通过本身来预测targetsource.append(t)target.append(t)elif r < 0.15:# 通过随机token来预测targetsource.append(np.random.choice(tokenizer._vocab_size - 1) &#43; 1)target.append(t)else:# 通过本身->label&#61;0?source.append(t)target.append(0)return source, targetclass data_generator(DataGenerator):def __init__(self, prefix&#61;False, *args, **kwargs):super(data_generator, self).__init__(*args, **kwargs)self.prefix &#61; prefixdef __iter__(self, shuffle&#61;False):batch_token_ids, batch_segment_ids, batch_target_ids &#61; [], [], []# 拿到query和replyfor is_end, (q, r, label) in self.get_sample(shuffle):# 没有label的时候定义为Nonelabel &#61; int(label) if label is not None else None# 有label的时候&#xff0c;才添加前缀if label is not None or self.prefix:q &#61; pattern &#43; q# 拿到token_ids和segment_idtoken_ids, segment_ids &#61; tokenizer.encode(q, r, maxlen&#61;maxlen)# 本文没有用到这个if shuffle:# 这里做了随机mask&#xff0c;随机mask有点没看懂, 但是本文都没用到这个source_tokens, target_tokens &#61; random_masking(token_ids)else:# 理论上target_tokens就等于source_tokenssource_tokens, target_tokens &#61; token_ids[:], token_ids[:]# mask labelif label is not None:# 将label转化成token&#xff0c;因为是mlm任务&#xff0c;最终的label其实就是tokenlabel_ids &#61; tokenizer.encode(id2label[label])[0][1:-1]# pattern &#61; &#39;直接回答问题:&#39;# mask_idx &#61; [1]# 这里label_ids也只有一个&#xff0c;所以是直接复制# mask_idx代表的其实是label在原文中的位置for m, lb in zip(mask_idx, label_ids):# 这里相当于把原文的label更换成为mask_id# source_tokens[1] &#61; mask_id# 然后target_tokens[1] &#61; label_id(也就是label对应的token_id)# 这里只更改了label对应的token&#xff0c;其余部分不变source_tokens[m] &#61; tokenizer._token_mask_idtarget_tokens[m] &#61; lbelif self.prefix:# 这里就一个mask_id&#xff0c;如果有多个多个都直接赋值成为token_idfor i in mask_idx:source_tokens[i] &#61; tokenizer._token_mask_id# 最后拿到mlm任务的source_tokens,segment_ids,target_tokensbatch_token_ids.append(source_tokens)batch_segment_ids.append(segment_ids)batch_target_ids.append(target_tokens)if is_end or len(batch_token_ids) &#61;&#61; self.batch_size:# 满足batch_size要求了&#xff0c;把他yield出去batch_token_ids &#61; pad_sequences(batch_token_ids)batch_segment_ids &#61; pad_sequences(batch_segment_ids)batch_target_ids &#61; pad_sequences(batch_target_ids)# batch_target_ids是每个位置target的idyield [batch_token_ids, batch_segment_ids, batch_target_ids], None# 将原始的batch里面的内容置为空batch_token_ids, batch_segment_ids, batch_target_ids &#61; [], [], []class CrossEntropy(Loss):"""交叉熵作为loss&#xff0c;并mask掉输入部分"""def compute_loss(self, inputs, mask&#61;None):y_true, y_pred &#61; inputs# K.not_equal, 拿到y_true不为0的部分&#xff0c;然后转化成为floaty_mask &#61; K.cast(K.not_equal(y_true, 0), K.floatx())# 计算精度accuracy &#61; keras.metrics.sparse_categorical_accuracy(y_true, y_pred)# mask掉输入部分accuracy &#61; K.sum(accuracy * y_mask) / K.sum(y_mask)# 拿到acc精度self.add_metric(accuracy, name&#61;&#39;accuracy&#39;)# 拿到交叉熵loss &#61; K.sparse_categorical_crossentropy(y_true, y_pred)# maskloss &#61; K.sum(loss * y_mask) / K.sum(y_mask)return loss# tokenizer
# tokenizer &#61; Tokenizer(dict_path, do_lower_case&#61;True)def train(train_data, val_data, test_data, best_model_file, test_result_file):train_generator &#61; data_generator(data&#61;train_data &#43; test_data, batch_size&#61;batch_size)valid_generator &#61; data_generator(data&#61;val_data, batch_size&#61;batch_size)test_generator &#61; data_generator(data&#61;test_data, batch_size&#61;batch_size, prefix&#61;True)target_in &#61; Input(shape&#61;(None,))model &#61; build_transformer_model(config_path&#61;config_path,checkpoint_path&#61;checkpoint_path,with_mlm&#61;True, # with_nlm为True是不是返回的output就不一样了&#xff0c;应该返回的就是mlm的output# model&#61;&#39;bert&#39;, # 加载bert/Roberta/erniemodel&#61;&#39;nezha&#39;)output &#61; CrossEntropy(output_idx&#61;1)([target_in, model.output])# 输入的时候&#xff0c;添加一个target_in&#xff0c; 输出还是和之前一样train_model &#61; Model(model.inputs &#43; [target_in], output)# 梯度衰减&#43;梯度积累AdamW &#61; extend_with_weight_decay(Adam)AdamWG &#61; extend_with_gradient_accumulation(AdamW)opt &#61; AdamWG(learning_rate&#61;1e-5, exclude_from_weight_decay&#61;[&#39;Norm&#39;, &#39;bias&#39;], grad_accum_steps&#61;4)train_model.compile(opt)train_model.summary()def evaluate(data):P, R, TP &#61; 0., 0., 0.for d, _ in tqdm(data):x_true, y_true &#61; d[:2], d[2]# 拿到预测结果&#xff0c;已经转化为label_ids里面的index了y_pred &#61; predict(x_true)# 只取mask_idx对应的y -> 原始token -> 原始label中的indexy_true &#61; np.array([labels.index(tokenizer.decode(y)) for y in y_true[:, mask_idx]])# print(y_true, y_pred)# 计算f1R &#43;&#61; y_pred.sum()P &#43;&#61; y_true.sum()TP &#43;&#61; ((y_pred &#43; y_true) > 1).sum()print(P, R, TP)pre &#61; TP / Rrec &#61; TP / Preturn 2 * (pre * rec) / (pre &#43; rec)def predict(x):if len(x) &#61;&#61; 3:x &#61; x[:2]# 拿到mask_idx对应的output# todo:这里这个model为什么不是train_model啊?y_pred &#61; model.predict(x)[:, mask_idx]# 这个维度信息不太清楚# batch, 0,label_ids对应的值, label_ids应该是可能有多个id&#xff0c;对应分类的多个类别y_pred &#61; y_pred[:, 0, label_ids[:, 0]]# 最后是取得所有label_ids里面的最大值&#xff0c;得到mlm的预测结果的&#xff0c;这里面的mlm的预测的结果的个数与分类的label数一致y_pred &#61; y_pred.argmax(axis&#61;1)return y_predclass Evaluator(keras.callbacks.Callback):def __init__(self, valid_generator, best_pet_model_file&#61;"best_pet_model.weights"):self.best_acc &#61; 0.self.valid_generator &#61; valid_generatorself.best_pet_model_file &#61; best_pet_model_filedef on_epoch_end(self, epoch, logs&#61;None):acc &#61; evaluate(self.valid_generator)if acc > self.best_acc:self.best_acc &#61; accself.model.save_weights(self.best_pet_model_file)print(&#39;acc :{}, best acc:{}&#39;.format(acc, self.best_acc))def write_to_file(path, test_generator, test_data):preds &#61; []# 分批预测结果for x, _ in tqdm(test_generator):pred &#61; predict(x)preds.extend(pred)# 把原始的query&#xff0c;reply以及预测的p都写入到文件中ret &#61; []for data, p in zip(test_data, preds):if data[2] is None:label &#61; -1else:label &#61; data[2]ret.append([data[0], data[1], str(label), str(p)])with open(path, &#39;w&#39;) as f:for r in ret:f.write(&#39;\t&#39;.join(r) &#43; &#39;\n&#39;)evaluator &#61; Evaluator(valid_generator, best_model_file)train_model.fit_generator(train_generator.generator(),steps_per_epoch&#61;len(train_generator),epochs&#61;10,callbacks&#61;[evaluator])train_model.load_weights(best_model_file)write_to_file(test_result_file, test_generator, test_data)def load_pair_data(f, isshuffle&#61;False):data &#61; []df &#61; pd.read_csv(f)if isshuffle:df &#61; df.sample(frac&#61;1.0, random_state&#61;1234)columns &#61; list(df.columns)if &#39;text_a&#39; not in columns and &#39;query1&#39; in columns:df.rename(columns&#61;{&#39;query1&#39;:&#39;text_a&#39;, &#39;query2&#39;:&#39;text_b&#39;}, inplace&#61;True)for i in range(len(df)):can &#61; df.iloc[i]text_a &#61; can[&#39;text_a&#39;]text_b &#61; can[&#39;text_b&#39;]if &#39;label&#39; not in columns:label &#61; Noneelse:label &#61; int(can[&#39;label&#39;])if label &#61;&#61; -1:label &#61; Nonedata.append([text_a, text_b, label])return datadef load_data():""":return: [text_a, text_b, label]天池疫情文本匹配数据集"""data_dir &#61; &#39;../data/tianchi/&#39;train_file &#61; data_dir &#43; &#39;train_20200228.csv&#39;dev_file &#61; data_dir &#43; &#39;dev_20200228.csv&#39;test_file &#61; data_dir &#43; &#39;test.example_20200228.csv&#39;train_data &#61; load_pair_data(train_file)val_data &#61; load_pair_data(dev_file)test_data &#61; load_pair_data(test_file)return train_data, val_data, test_datadef test_data_generator():data_dir &#61; &#39;../data/tianchi/&#39;train_file &#61; data_dir &#43; &#39;train_20200228.csv&#39;data &#61; load_pair_data(train_file)train_generator &#61; data_generator(data&#61;data, batch_size&#61;batch_size)for d in train_generator:print(d)breakdef run():train_data, val_data, test_data &#61; load_data()best_model_file &#61; &#39;best_pet_model.weights&#39;test_result_file &#61; &#39;pet_submission.tsv&#39;train(train_data, val_data, test_data, best_model_file, test_result_file)if __name__ &#61;&#61; &#39;__main__&#39;:test_data_generator()run()

推荐阅读
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • 本文介绍了Android 7的学习笔记总结,包括最新的移动架构视频、大厂安卓面试真题和项目实战源码讲义。同时还分享了开源的完整内容,并提醒读者在使用FileProvider适配时要注意不同模块的AndroidManfiest.xml中配置的xml文件名必须不同,否则会出现问题。 ... [详细]
  • Go Cobra命令行工具入门教程
    本文介绍了Go语言实现的命令行工具Cobra的基本概念、安装方法和入门实践。Cobra被广泛应用于各种项目中,如Kubernetes、Hugo和Github CLI等。通过使用Cobra,我们可以快速创建命令行工具,适用于写测试脚本和各种服务的Admin CLI。文章还通过一个简单的demo演示了Cobra的使用方法。 ... [详细]
  • web.py开发web 第八章 Formalchemy 服务端验证方法
    本文介绍了在web.py开发中使用Formalchemy进行服务端表单数据验证的方法。以User表单为例,详细说明了对各字段的验证要求,包括必填、长度限制、唯一性等。同时介绍了如何自定义验证方法来实现验证唯一性和两个密码是否相等的功能。该文提供了相关代码示例。 ... [详细]
  • Python爬虫中使用正则表达式的方法和注意事项
    本文介绍了在Python爬虫中使用正则表达式的方法和注意事项。首先解释了爬虫的四个主要步骤,并强调了正则表达式在数据处理中的重要性。然后详细介绍了正则表达式的概念和用法,包括检索、替换和过滤文本的功能。同时提到了re模块是Python内置的用于处理正则表达式的模块,并给出了使用正则表达式时需要注意的特殊字符转义和原始字符串的用法。通过本文的学习,读者可以掌握在Python爬虫中使用正则表达式的技巧和方法。 ... [详细]
  • 树莓派语音控制的配置方法和步骤
    本文介绍了在树莓派上实现语音控制的配置方法和步骤。首先感谢博主Eoman的帮助,文章参考了他的内容。树莓派的配置需要通过sudo raspi-config进行,然后使用Eoman的控制方法,即安装wiringPi库并编写控制引脚的脚本。具体的安装步骤和脚本编写方法在文章中详细介绍。 ... [详细]
  • iOS超签签名服务器搭建及其优劣势
    本文介绍了搭建iOS超签签名服务器的原因和优势,包括不掉签、用户可以直接安装不需要信任、体验好等。同时也提到了超签的劣势,即一个证书只能安装100个,成本较高。文章还详细介绍了超签的实现原理,包括用户请求服务器安装mobileconfig文件、服务器调用苹果接口添加udid等步骤。最后,还提到了生成mobileconfig文件和导出AppleWorldwideDeveloperRelationsCertificationAuthority证书的方法。 ... [详细]
  • 本文介绍了Oracle存储过程的基本语法和写法示例,同时还介绍了已命名的系统异常的产生原因。 ... [详细]
  • 使用圣杯布局模式实现网站首页的内容布局
    本文介绍了使用圣杯布局模式实现网站首页的内容布局的方法,包括HTML部分代码和实例。同时还提供了公司新闻、最新产品、关于我们、联系我们等页面的布局示例。商品展示区包括了车里子和农家生态土鸡蛋等产品的价格信息。 ... [详细]
  • 十大经典排序算法动图演示+Python实现
    本文介绍了十大经典排序算法的原理、演示和Python实现。排序算法分为内部排序和外部排序,常见的内部排序算法有插入排序、希尔排序、选择排序、冒泡排序、归并排序、快速排序、堆排序、基数排序等。文章还解释了时间复杂度和稳定性的概念,并提供了相关的名词解释。 ... [详细]
  • Gitlab接入公司内部单点登录的安装和配置教程
    本文介绍了如何将公司内部的Gitlab系统接入单点登录服务,并提供了安装和配置的详细教程。通过使用oauth2协议,将原有的各子系统的独立登录统一迁移至单点登录。文章包括Gitlab的安装环境、版本号、编辑配置文件的步骤,并解决了在迁移过程中可能遇到的问题。 ... [详细]
  • Ihaveaworkfolderdirectory.我有一个工作文件夹目录。holderDir.glob(*)>holder[ProjectOne, ... [详细]
  • Commit1ced2a7433ea8937a1b260ea65d708f32ca7c95eintroduceda+Clonetraitboundtom ... [详细]
  • 本文介绍了在Java中检查字符串是否仅包含数字的方法,包括使用正则表达式的示例代码,并提供了测试案例进行验证。同时还解释了Java中的字符转义序列的使用。 ... [详细]
author-avatar
泉州多棱汽车销售服务有限公司
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有