热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

关于算法:图神经网络之预训练大模型结合ERNIESage在链接预测任务应用

在很多工业利用中,往往呈现如下图所示的一种非凡的图:TextGraph。顾名思义,图的节点属性由文本形成,而边的构建提供了构造信息。如搜寻场景下的TextGrap

1.ERNIESage运行实例介绍(1.8x版本)

本我的项目原链接:https://aistudio.baidu.com/aistudio/projectdetail/5097085?cOntributionType=1

本我的项目次要是为了间接提供一个能够运行ERNIESage模型的环境,

https://github.com/PaddlePadd…

在很多工业利用中,往往呈现如下图所示的一种非凡的图:Text Graph。顾名思义,图的节点属性由文本形成,而边的构建提供了构造信息。如搜寻场景下的Text Graph,节点可由搜索词、网页题目、网页注释来表白,用户反馈和超链信息则可形成边关系。

ERNIESage 由PGL团队提出,是ERNIE SAmple aggreGatE的简称,该模型能够同时建模文本语义与图构造信息,无效晋升 Text Graph 的利用成果。其中 ERNIE 是百度推出的基于常识加强的继续学习语义了解框架。

ERNIESage 是 ERNIE 与 GraphSAGE 碰撞的后果,是 ERNIE SAmple aggreGatE 的简称,它的构造如下图所示,次要思维是通过 ERNIE 作为聚合函数(Aggregators),建模本身节点和街坊节点的语义与构造关系。ERNIESage 对于文本的建模是构建在街坊聚合的阶段,核心节点文本会与所有街坊节点文本进行拼接;而后通过预训练的 ERNIE 模型进行音讯汇聚,捕获核心节点以及街坊节点之间的互相关系;最初应用 ERNIESage 搭配独特的街坊相互看不见的 Attention Mask 和独立的 Position Embedding 体系,就能够轻松构建 TextGraph 中句子之间以及词之间的关系。

应用ID特色的GraphSAGE只可能建模图的构造信息,而独自的ERNIE只能解决文本信息。通过PGL搭建的图与文本的桥梁,ERNIESage可能很简略的把GraphSAGE以及ERNIE的长处联合一起。以上面TextGraph的场景,ERNIESage的成果可能比独自的ERNIE以及GraphSAGE模型都要好。

ERNIESage能够很轻松地在PGL中的消息传递范式中进行实现,目前PGL在github上提供了3个版本的ERNIESage模型:

  • ERNIESage v1: ERNIE 作用于text graph节点上;
  • ERNIESage v2: ERNIE 作用在text graph的边上;
  • ERNIESage v3: ERNIE 作用于一阶街坊及起边上;

次要会针对ERNIESageV1和ERNIESageV2版本进行一个介绍。

1.1算法实现

可能有同学对于整个我的项目代码文件都不太理解,因而这里会做一个比较简单的解说。

外围局部蕴含:

  • 数据集局部
  • data.txt – 简略的输出文件,格局为每行query \t answer,可作简略的运行实例应用。
  • 模型文件和配置局部
  • ernie_config.json – ERNIE模型的配置文件。
  • vocab.txt – ERNIE模型所应用的词表。
  • ernie_base_ckpt/ – ERNIE模型参数。
  • config/ – ERNIESage模型的配置文件,蕴含了三个版本的配置文件。
  • 代码局部
  • local_run.sh – 入口文件,通过该入口可实现预处理、训练、infer三个步骤。
  • preprocessing文件夹 – 蕴含dump_graph.py, tokenization.py。在预处理局部,咱们首先须要进行建图,将输出的文件构建成一张图。因为咱们所钻研的是Text Graph,因而节点都是文本,咱们将文本示意为该节点对应的node feature(节点特色),解决文本的时候须要进行切字,再映射为对应的token id。
  • dataset/ – 该文件夹蕴含了数据ready的代码,以便于咱们在训练的时候将训练数据以batch的形式读入。
  • models/ – 蕴含了ERNIESage模型外围代码。
  • train.py – 模型训练入口文件。
  • learner.py – 分布式训练代码,通过train.py调用。
  • infer.py – infer代码,用于infer出节点对应的embedding。
  • 评估局部
  • build_dev.py – 用于将咱们的验证集批改为须要的格局。
  • mrr.py – 计算MRR值。

要在这个我的项目中运行模型其实很简略,只有运行下方的入口命令就ok啦!然而,须要留神的是,因为ERNIESage模型比拟大,所以如果AIStudio中的CPU版本运行模型容易出问题。因而,在运行部署环境时,倡议抉择GPU的环境。

另外,如果提醒呈现了GPU空间有余等问题,咱们能够通过调小对应yaml文件中的batch_size来调整,也能够批改ERNIE模型的配置文件ernie_config.json,将num_hidden_layers设小一些。在这里,我仅提供了ERNIESageV2版本的gpu运行过程,如果同学们想运行其余版本的模型,能够依据须要批改下方的命令。

运行结束后,会产生较多的文件,这里进行简略的解释。

  1. workdir/ – 这个文件夹次要会存储和图相干的数据信息。
  2. output/ – 次要的输入文件夹,蕴含了以下内容:(1)模型文件,依据config文件中的save_per_step可调整保留模型的频率,如果设置得比拟大则可能训练过程中不会保留模型; (2)last文件夹,保留了进行训练时的模型参数,在infer阶段咱们会应用这部分模型参数;(3)part-0文件,infer之后的输出文件中所有节点的Embedding输入。

为了能够比较清楚地晓得Embedding的成果,咱们间接通过MRR简略判断一下data.txt计算出来的Embedding后果,此处将data.txt同时作为训练集和验证集。

1.2 外围模型代码解说

首先,咱们能够通过查看models/model_factory.py来判断在本我的项目有多少种ERNIESage模型。

from models.base import BaseGNNModel
from models.ernie import ErnieModel
from models.erniesage_v1 import ErnieSageModelV1
from models.erniesage_v2 import ErnieSageModelV2
from models.erniesage_v3 import ErnieSageModelV3

class Model(object):
    @classmethod
    def factory(cls, config):
        name = config.model_type
        if name == "BaseGNNModel":
            return BaseGNNModel(config)
        if name == "ErnieModel":
            return ErnieModel(config)
        if name == "ErnieSageModelV1":
            return ErnieSageModelV1(config)
        if name == "ErnieSageModelV2":
            return ErnieSageModelV2(config)
        if name == "ErnieSageModelV3":
            return ErnieSageModelV3(config)
        else:
            raise ValueError

能够看到一共有ERNIESage模型一共有3个版本,另外咱们也提供了根本的GNN模型和ERNIE模型,感兴趣的同学能够自行查阅。

接下来,我次要会针对ERNIESageV1和ERNIESageV2这两个版本的模型进行要害局部的解说,次要的不同其实就是消息传递机制(Message Passing)局部的不同。

1.2.1 ERNIESageV1要害代码

# ERNIESageV1的Message Passing代码
# 查找门路:erniesage_v1.py(__call__中的self.gnn_layers) -> base.py(BaseNet类中的gnn_layers办法) -> message_passing.py

# erniesage_v1.py
def __call__(self, graph_wrappers):
    inputs = self.build_inputs()
    feature = self.build_embedding(graph_wrappers, inputs[-1])  # 将节点的文本信息利用ERNIE模型建模,生成对应的Embedding作为feature
    features = self.gnn_layers(graph_wrappers, feature)  # GNN模型的次要不同,消息传递机制入口
    outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
    src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
    outputs.append(src_real_index)
    return inputs, outputs

# base.py -> BaseNet
def gnn_layers(self, graph_wrappers, feature):
    features = [feature]
    initializer = None
    fc_lr = self.config.lr / 0.001
    for i in range(self.config.num_layers):
        if i == self.config.num_layers - 1:
            act = None
        else:
            act = "leaky_relu"
        feature = get_layer(  
            self.config.layer_type, # 对于ERNIESageV1, 其layer_type="graphsage_sum",能够到config文件夹中查看
            graph_wrappers[i],
            feature,
            self.config.hidden_size,
            act,
            initializer,
            learning_rate=fc_lr,
            name="%s_%s" % (self.config.layer_type, i))
        features.append(feature)
    return features

# message_passing.py
def graphsage_sum(gw, feature, hidden_size, act, initializer, learning_rate, name):
    """doc"""
    msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
    neigh_feature = gw.recv(msg, sum_recv)                # Recv
    self_feature = feature
    self_feature = fluid.layers.fc(self_feature,
                                   hidden_size,
                                   act=act,
                                   param_attr=fluid.ParamAttr(name=name + "_l.w_0", initializer=initializer,
                                   learning_rate=learning_rate),
                                    bias_attr=name+"_l.b_0"
                                   )
    neigh_feature = fluid.layers.fc(neigh_feature,
                                    hidden_size,
                                    act=act,
                                    param_attr=fluid.ParamAttr(name=name + "_r.w_0", initializer=initializer,
                                   learning_rate=learning_rate),
                                    bias_attr=name+"_r.b_0"
                                    )
    output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
    output = fluid.layers.l2_normalize(output, axis=1)
    return output

通过上述代码片段能够看到,要害的消息传递机制代码就是graphsage_sum函数,其中send、recv局部如下。

def copy_send(src_feat, dst_feat, edge_feat):
    """doc"""
    return src_feat["h"]
    
msg = gw.send(copy_send, nfeat_list=[("h", feature)]) # Send
neigh_feature = gw.recv(msg, sum_recv)                # Recv

通过代码能够看到,ERNIESageV1版本,其次要是针对节点街坊,间接将以后节点的街坊节点特色求和。再看到graphsage_sum函数中,将街坊节点特色进行求和后,失去了neigh_feature。随后,咱们将节点自身的特色self_feature和街坊聚合特色neigh_feature通过fc层后,间接concat起来,从而失去了以后gnn layer层的feature输入。

1.2.2ERNIESageV2要害代码

ERNIESageV2的消息传递机制代码次要在erniesage_v2.py和message_passing.py,绝对ERNIESageV1来说,代码会绝对长了一些。

为了使得大家对上面无关ERNIE模型的局部可能有所理解,这里先贴出ERNIE的主模型框架图。

具体的代码解释能够间接看正文。

# ERNIESageV2的Message Passing代码

# 上面的函数都在erniesage_v2.py的ERNIESageV2类中
# ERNIESageV2的调用函数
def __call__(self, graph_wrappers):
    inputs = self.build_inputs()
    feature = inputs[-1]
    features = self.gnn_layers(graph_wrappers, feature) 
    outputs = [self.take_final_feature(features[-1], i, "final_fc") for i in inputs[:-1]]
    src_real_index = L.gather(graph_wrappers[0].node_feat['index'], inputs[0])
    outputs.append(src_real_index)
    return inputs, outputs

# 进入self.gnn_layers函数
def gnn_layers(self, graph_wrappers, feature):
    features = [feature]

    initializer = None
    fc_lr = self.config.lr / 0.001

    for i in range(self.config.num_layers):
        if i == self.config.num_layers - 1:
            act = None
        else:
            act = "leaky_relu"

        feature = self.gnn_layer(
            graph_wrappers[i],
            feature,
            self.config.hidden_size,
            act,
            initializer,
            learning_rate=fc_lr,
            name="%s_%s" % ("erniesage_v2", i))
        features.append(feature)
    return features
接下来会进入ERNIESageV2次要的代码局部。

能够看到,在ernie_send函数用于将咱们的街坊信息发送到以后节点。在ERNIESageV1中,咱们在Send阶段对街坊节点通过ERNIE模型失去Embedding后,再间接求和,实际上以后节点和街坊节点之间的文本信息在消息传递过程中是没有间接交互的,直到最初才**concat**起来;而ERNIESageV2中,在Send阶段,源节点和指标节点的信息会间接concat起来,通过ERNIE模型失去一个对立的Embedding,这样就失去了源节点和指标节点的一个信息交互过程,这个局部能够查看上面的ernie_send函数。

gnn_layer函数中蕴含了三个函数:
1. ernie_send: 将src和dst节点对应文本concat后,过Ernie后失去须要的msg,更加具体的解释能够看下方代码正文。
2. build_position_ids: 次要是为了创立地位ID,提供给Ernie,从而能够产生position embeddings。
3. erniesage_v2_aggregator: gnn_layer的入口函数,蕴含了消息传递机制,以及聚合后的音讯feature处理过程。
# 进入self.gnn_layer函数
def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
    def build_position_ids(src_ids, dst_ids): # 此函数用于创立地位ID,能够对应到ERNIE框架图中的Position Embeddings
        # ...
        pass
    def ernie_send(src_feat, dst_feat, edge_feat): 
        """doc"""
        # input_ids,能够对应到ERNIE框架图中的Token Embeddings
        cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1)
        src_ids = L.concat([cls, src_feat["term_ids"]], 1)
        dst_ids = dst_feat["term_ids"]
        term_ids = L.concat([src_ids, dst_ids], 1)

        # sent_ids,能够对应到ERNIE框架图中的Segment Embeddings
        sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
        
        # position_ids,能够对应到ERNIE框架图中的Position Embeddings
        position_ids = build_position_ids(src_ids, dst_ids)

        term_ids.stop_gradient = True
        sent_ids.stop_gradient = True
        ernie = ErnieModel( # ERNIE模型
            term_ids, sent_ids, position_ids,
            cOnfig=self.config.ernie_config)
        feature = ernie.get_pooled_output() # 失去发送过去的msg,该msg是由src节点和dst节点的文本特色一起过ERNIE后失去的embedding
        return feature
    def erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name):
        feature = L.unsqueeze(feature, [-1])
        msg = gw.send(ernie_send, nfeat_list=[("term_ids", feature)]) # Send
        neigh_feature = gw.recv(msg, lambda feat: F.layers.sequence_pool(feat, pool_type="sum")) # Recv,间接将发送来的msg依据dst节点来相加。
        
        # 接下来的局部和ERNIESageV1相似,将self_feature和neigh_feature通过concat、normalize后失去须要的输入。
        term_ids = feature
        cls = L.fill_constant_batch_size_like(term_ids, [-1, 1, 1], "int64", 1)
        term_ids = L.concat([cls, term_ids], 1)
        term_ids.stop_gradient = True
        ernie = ErnieModel(
            term_ids, L.zeros_like(term_ids),
            cOnfig=self.config.ernie_config)
        self_feature = ernie.get_pooled_output()
        self_feature = L.fc(self_feature,
                                        hidden_size,
                                        act=act,
                                        param_attr=F.ParamAttr(name=name + "_l.w_0",
                                        learning_rate=learning_rate),
                                        bias_attr=name+"_l.b_0"
                                        )
        neigh_feature = L.fc(neigh_feature,
                                        hidden_size,
                                        act=act,
                                        param_attr=F.ParamAttr(name=name + "_r.w_0",
                                        learning_rate=learning_rate),
                                        bias_attr=name+"_r.b_0"
                                        )
        output = L.concat([self_feature, neigh_feature], axis=1)
        output = L.l2_normalize(output, axis=1)
        return output
    return erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name)
    

2.总结

通过以上两个版本的模型代码简略的解说,咱们能够晓得他们的不同点,其实次要就是在消息传递机制的局部有所不同。ERNIESageV1版本只作用在text graph的节点上,在传递音讯(Send阶段)时只思考了街坊自身的文本信息;而ERNIESageV2版本则作用在了边上,在Send阶段同时思考了以后节点和其街坊节点的文本信息,达到更好的交互成果。


推荐阅读
  • 先上图引入插件在pubspec.yaml中引入charts_flutter插件使用的时候版本到0.6.0,插件地址:https:github.comgooglecharts使用插件 ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 在Docker中,将主机目录挂载到容器中作为volume使用时,常常会遇到文件权限问题。这是因为容器内外的UID不同所导致的。本文介绍了解决这个问题的方法,包括使用gosu和suexec工具以及在Dockerfile中配置volume的权限。通过这些方法,可以避免在使用Docker时出现无写权限的情况。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 云原生边缘计算之KubeEdge简介及功能特点
    本文介绍了云原生边缘计算中的KubeEdge系统,该系统是一个开源系统,用于将容器化应用程序编排功能扩展到Edge的主机。它基于Kubernetes构建,并为网络应用程序提供基础架构支持。同时,KubeEdge具有离线模式、基于Kubernetes的节点、群集、应用程序和设备管理、资源优化等特点。此外,KubeEdge还支持跨平台工作,在私有、公共和混合云中都可以运行。同时,KubeEdge还提供数据管理和数据分析管道引擎的支持。最后,本文还介绍了KubeEdge系统生成证书的方法。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 标题: ... [详细]
  • Java 11相对于Java 8,OptaPlanner性能提升有多大?
    本文通过基准测试比较了Java 11和Java 8对OptaPlanner的性能提升。测试结果表明,在相同的硬件环境下,Java 11相对于Java 8在垃圾回收方面表现更好,从而提升了OptaPlanner的性能。 ... [详细]
  • 本文整理了Java面试中常见的问题及相关概念的解析,包括HashMap中为什么重写equals还要重写hashcode、map的分类和常见情况、final关键字的用法、Synchronized和lock的区别、volatile的介绍、Syncronized锁的作用、构造函数和构造函数重载的概念、方法覆盖和方法重载的区别、反射获取和设置对象私有字段的值的方法、通过反射创建对象的方式以及内部类的详解。 ... [详细]
  • docker+k8s+git+jenkins
    docker+k8s+git+jenkins,Go语言社区,Golang程序员人脉社 ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • baresip android编译、运行教程1语音通话
    本文介绍了如何在安卓平台上编译和运行baresip android,包括下载相关的sdk和ndk,修改ndk路径和输出目录,以及创建一个c++的安卓工程并将目录考到cpp下。详细步骤可参考给出的链接和文档。 ... [详细]
  • 如何用UE4制作2D游戏文档——计算篇
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了如何用UE4制作2D游戏文档——计算篇相关的知识,希望对你有一定的参考价值。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
author-avatar
mobiledu2502885307
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有