热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

考虑关系的图卷积神经网络RGCN的一些理解以及DGL官方代码的一些讲解

考虑关系的图卷积神经网络R-GCN的一些理解以及DGL官方代码的一些讲解-文章目录前言R-GCN传播公式正则化DGL中的R-GCN实体分类的实例nn.Parameter





文章目录


  • 前言
  • R-GCN
    • 传播公式
    • 正则化

  • DGL中的R-GCN实体分类的实例
    • nn.Parameter
    • torch.matmul

  • 参考


前言

昨天写的GCN的一篇文章入榜了,可喜可贺。但是感觉距离我的目标还是有点远,因为最后要用R-GAT,我感觉可能得再懂一点R-GCN和GAT才可能比较好的理解R-GAT,今天就尝试一下把R-GCN搞搞清楚吧(至少得读懂DGL官方给的代码吧)


R-GCN

R-GCN和GCN的区别就在于这个R。R-GCN考虑了关系对消息传播的影响。
带你快速理解R-GCN(relational-GCN)
看看b站这个视频就应该能够比较形象的知道R-GCN是有什么创新点了,讲的还是非常不错的。

R-GCN解决了什么问题,可以这么形象的理解:

例如假设节点之间有那么些关系,如果不考虑边,仅仅GCN的话,那么1节点更新参数时,2节点和3节点传来的特征比例应该是一样的,那么就不太合理了,因为你的仇人和你的朋友怎么能够一概而论呢?

R-GCN就是考虑到了这一点,因此将关系加入考虑特征更新的操作。


传播公式


其中
σ()是激活函数
hj(l)是值输入的邻居节点的特征
hi(l) 是指输入的自己节点的特征
hi(l+1) 是指更新后的自己节点的特征
Wrl是对应关系特征的权重矩阵,因为考虑不同的关系,那么有多少种关系就有多少个Wr
W0就是自己的特征权重矩阵
ci,r是正则化常量

虽然这个公式看起来复杂,但是其实就是考虑了各种不同关系的GCN,GCN的是通过权重*特征来计算节点特征,权重就是通过度来计算出来;而R-GCN则是通过关系矩阵(Wrl) * hjl对应的特征矩阵来计算出对应节点的特征,Wr的个数就是图中所有节点之间的关系数+1(自连也算一种关系,需要一个即W0)。


正则化


正则化的意义就是在实验中发现,当图变大一点的时候,其关系的数量会迅速增长,从而导致模型训练困难,不仅如此,可能某些边的训练数据会不那么多,就会很容易导致过拟合的产生,基于此,论文中提到了一种正则化的技巧。

其核心思想就是对于每个关系r,共享参数Vb,训练的只是arb,这样共享参数的方式就减少了训练参数的数量,也缓解了过拟合。

对应到代码里,应该就是相当于将关系矩阵降维成一个小的关系矩阵之后,再进行特征的计算。
例如原来的关系矩阵可能是1000 * 1000,特征值维度为300 ,那么接个线性层 1000 * 10就将1000种关系转换成了10维来表示,然后再接一个10 * 300的线性层得出对应特征的权重,再得出特征向量的值即可完成考虑关系得出对应节点特征值的操作,然后汇聚即可。

如果这样的话,训练的参数应该是 1000 * 10 + 10 * 300
如果不这样,直接计算,那么训练的参数应该 1000 * 300
这个就是他正则化的好处了吧,应该是这样吧,我从代码里感觉出是这样的。


DGL中的R-GCN实体分类的实例


模型流程图就是这么简单

谁能想到,那代码拷下了居然有小bug,可能版本问题吧,但是稍微改改就可以跑了。

环境配置:
dgl0.6.1
torch
1.9.1

理解代码前还是需要学习一下某些相关的操作的。


nn.Parameter

import torch.nn as nn
a = nn.Parameter(torch.Tensor(3, 2, 2))
print(a)
b = a[torch.tensor([1, 1, 2, 2])]
print(b)
print(b.shape)

输出:

Parameter containing:
tensor([[[1.0194e-38, 9.6429e-39],
[9.2755e-39, 9.1837e-39]],
[[9.3674e-39, 1.0745e-38],
[1.0653e-38, 9.5510e-39]],
[[1.0561e-38, 1.0194e-38],
[1.1112e-38, 1.0561e-38]]], requires_grad=True)
tensor([[[9.3674e-39, 1.0745e-38],
[1.0653e-38, 9.5510e-39]],
[[9.3674e-39, 1.0745e-38],
[1.0653e-38, 9.5510e-39]],
[[1.0561e-38, 1.0194e-38],
[1.1112e-38, 1.0561e-38]],
[[1.0561e-38, 1.0194e-38],
[1.1112e-38, 1.0561e-38]]], grad_fn=)
torch.Size([4, 2, 2])

这个操作就是对能够更新的参数矩阵a中,索引出a里面的各个元素。
(看一眼就懂了,以前没这么用过,看到这代码开始的时候看不太懂)


torch.matmul

这个就是矩阵乘法,举两个例子就大概知道它有啥特别的了:

import torch
a = torch.tensor([[1, 2],
[2, 2]])
b = torch.tensor([[[1, 4],
[2, 4]],
[[1, 5],
[2, 4]],
[[1, 5],
[2, 4]]])
c = torch.matmul(a, b)
print(a.shape)
print(b.shape)
print(c)

输出结果:

torch.Size([2, 2])
torch.Size([3, 2, 2])
tensor([[[ 5, 12],
[ 6, 16]],
[[ 5, 13],
[ 6, 18]],
[[ 5, 13],
[ 6, 18]]])


import torch
a = torch.tensor([[[1, 2],
[2, 2]],])
b = torch.tensor([[[1, 4],
[2, 4]],
[[1, 5],
[2, 4]],
[[1, 5],
[2, 4]]])
c = torch.matmul(a, b)
print(a.shape)
print(b.shape)
print(c)

输出结果

torch.Size([1, 2, 2])
torch.Size([3, 2, 2])
tensor([[[ 5, 12],
[ 6, 16]],
[[ 5, 13],
[ 6, 18]],
[[ 5, 13],
[ 6, 18]]])


import torch
a = torch.tensor([[[1, 2],
[2, 2]],
[[1, 5],
[2, 4]]])
b = torch.tensor([[[1, 4],
[2, 4]],
[[1, 5],
[2, 4]],
])
c = torch.matmul(a, b)
print(a.shape)
print(b.shape)
print(c)

输出结果:

torch.Size([2, 2, 2])
torch.Size([2, 2, 2])
tensor([[[ 5, 12],
[ 6, 16]],
[[11, 25],
[10, 26]]])


就是当两个三维矩阵第一维维度相同,那么就能对应元素进行矩阵乘法,如果第一维维度为1或者缺少一维,那么就是直接进行矩阵乘法,如果第一维维度不同且都不为1,那么就会报错。

看了一下午的R-GCN的代码,我感觉应该差不多可以理解代码的思路了,千言万语都在注释里了,并且可直接运行(如果版本和我匹配的话)

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import dgl.function as fn
from functools import partial
class RGCNLayer(nn.Module):
def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
activation=None, is_input_layer=False):
super(RGCNLayer, self).__init__()
# 输入的特征维度
self.in_feat = in_feat
# 输出的特征维度
self.out_feat = out_feat
# 关系的数量
self.num_rels = num_rels
# 基分解中W_r分解的数量,即B的大小
self.num_bases = num_bases
# 是否带偏置b
self.bias = bias
# 激活函数
self.activation = activation
# 是否是输入层
self.is_input_layer = is_input_layer
# 如果说没设定W_r的个数(B)或者W_r的个数比关系数大,那么就直接取关系数
# 因为这个正则化就是为了解决关系数过多而导致训练参数过多及过拟合的问题的
# 如果没有正则化优化正常来说有几个关系就对应几个W_r
# 因为因此肯定是B num_rels * num_bases
# torch.matmul(self.w_comp, weight)
# w_comp(num_rels * num_bases) weight(in_feat * num_bases * out_feat)
# ||
# V
# in_feat * num_rels * out_feat
# 再经过view操作
# weight --> num_rels * in_feat * out_feat
weight = torch.matmul(self.w_comp, weight).view(self.num_rels,
self.in_feat, self.out_feat)
else:
# 如果没有正则化,即直接取所有关系,原本初始化就是这个形状
# weight = num_rels * in_feat * out_feat
weight = self.weight
if self.is_input_layer:
# 如果是输入层,需要获得节点的embedding表达
def message_func(edges):
# for input layer, matrix multiply can be converted to be
# an embedding lookup using source node id
# 就这个例子来说,weight: num_rels * in_feat * out_feat = 91 * 8285 * 16
embed = weight.view(-1, self.out_feat) # embed = 753935 * 16
# 这句话真的看了半天才懂一点点
# edges.data['rel_type']存的是所有的关系 共65439个关系
# self.in_feat是输入的embedding,因为这里直接one-hot表示,因此大小为8285
# edges.src['id']是每个关系的源节点id, 这个参数看了好久也不知道是怎么来的,我感觉可能是节点有了“id”,然后传进来便就会有这个参数了吧
# 这个index还是很奇妙的
# 因为一共91种关系,8285个节点,那么最开始输入需要赋值节点自身的特征
# 那么每个节点对应的91种关系都有不同的表达
# edges.data['rel_type'] * self.in_feat + edges.src['id']就是取自身节点对应的关系的那种表达
# 从而获得节点的embedding表达
index = edges.data['rel_type'] * self.in_feat + edges.src['id']
return {'msg': embed[index] * edges.data['norm']}
else:
# 如果不是输入层那么用计算出 邻居特征*关系 的特征值
def message_func(edges):
# 取出对应关系的权重矩阵
w = weight[edges.data['rel_type'].long()]
# 矩阵乘法获取每条边需要传递的特征msg:(65439 * 4)
msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
msg = msg * edges.data['norm']
return {'msg': msg}
# msg求和作为节点特征
# 有偏置加偏置。
# 有激活加激活,主要是用于输出层设置的
def apply_func(nodes):
h = nodes.data['h']
if self.bias:
h = h + self.bias
if self.activation:
h = self.activation(h)
return {'h': h}
g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)
class Model(nn.Module):
def __init__(self, num_nodes, h_dim, out_dim, num_rels,
num_bases=-1, num_hidden_layers=1):
super(Model, self).__init__()
self.num_nodes = num_nodes
self.h_dim = h_dim
self.out_dim = out_dim
self.num_rels = num_rels
self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers
# 创建R-GCN层
self.build_model()
# 获取特征
self.features = self.create_features()
def build_model(self):
self.layers = nn.ModuleList()
# 输入层
i2h = self.build_input_layer()
self.layers.append(i2h)
# 隐藏层
for _ in range(self.num_hidden_layers):
h2h = self.build_hidden_layer()
self.layers.append(h2h)
# 输出层
h2o = self.build_output_layer()
self.layers.append(h2o)
# 初始胡化每个节点的特征
def create_features(self):
features = torch.arange(self.num_nodes)
return features
# 构建输入层
def build_input_layer(self):
return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,
activation=F.relu, is_input_layer=True)
# 构建隐藏层
def build_hidden_layer(self):
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
activation=F.relu)
# 构建输出层
def build_output_layer(self):
return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,
activation=partial(F.softmax, dim=1))
# 前向传播
def forward(self, g):
if self.features is not None:
g.ndata['id'] = self.features
for layer in self.layers:
layer(g)
# 取出每个节点的隐藏层并且删除"h"特征,方便下一次进行训练
return g.ndata.pop('h')
from dgl.contrib.data import load_data
data = load_data(dataset='aifb')
num_nodes = data.num_nodes # 节点数量
num_rels = data.num_rels # 关系数量
num_classes = data.num_classes # 分类的类别数
labels = data.labels # 标签
train_idx = data.train_idx # 训练集节点的index
# split training and validation set
val_idx = train_idx[:len(train_idx) // 5] # 划分验证集
train_idx = train_idx[len(train_idx) // 5:] # 划分训练集
edge_type = torch.from_numpy(data.edge_type) # 获取边的类型
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1) # 获取边的标准化因子
labels = torch.from_numpy(labels).view(-1)
# configurations
n_hidden = 16 # 每层的神经元个数
n_bases = -1 # 直接用所有的关系,不正则化
n_hidden_layers = 0 # 使用一层输入一层输出,不用隐藏层
n_epochs = 25 # 训练次数
lr = 0.01 # 学习率
# 创建图
g = DGLGraph((data.edge_src, data.edge_dst))
g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
# 创建模型
model = Model(len(g),
n_hidden,
num_classes,
num_rels,
num_bases=n_bases,
num_hidden_layers=n_hidden_layers)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)
print("start training...")
model.train()
for epoch in range(n_epochs):
optimizer.zero_grad()
logits = model.forward(g)
loss = F.cross_entropy(logits[train_idx], labels[train_idx].long())
loss.backward()
optimizer.step()
train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
train_acc = train_acc.item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx].long())
val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx])
val_acc = val_acc.item() / len(val_idx)
print("Epoch {:05d} | ".format(epoch) +
"Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
train_acc, loss.item()) +
"Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
val_acc, val_loss.item()))

其实代码并没有完完全全读懂,还是有一些地方有疑惑的,有不对的地方,希望大家能够批评指正!


参考

https://docs.dgl.ai/en/0.6.x/tutorials/models/1_gnn/4_rgcn.html
Modeling Relational Data with Graph Convolutional Networks





推荐阅读
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • vue使用
    关键词: ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文介绍了[从头学数学]中第101节关于比例的相关问题的研究和修炼过程。主要内容包括[机器小伟]和[工程师阿伟]一起研究比例的相关问题,并给出了一个求比例的函数scale的实现。 ... [详细]
  • 利用Visual Basic开发SAP接口程序初探的方法与原理
    本文介绍了利用Visual Basic开发SAP接口程序的方法与原理,以及SAP R/3系统的特点和二次开发平台ABAP的使用。通过程序接口自动读取SAP R/3的数据表或视图,在外部进行处理和利用水晶报表等工具生成符合中国人习惯的报表样式。具体介绍了RFC调用的原理和模型,并强调本文主要不讨论SAP R/3函数的开发,而是针对使用SAP的公司的非ABAP开发人员提供了初步的接口程序开发指导。 ... [详细]
  • Html5-Canvas实现简易的抽奖转盘效果
    本文介绍了如何使用Html5和Canvas标签来实现简易的抽奖转盘效果,同时使用了jQueryRotate.js旋转插件。文章中给出了主要的html和css代码,并展示了实现的基本效果。 ... [详细]
  • 深度学习中的Vision Transformer (ViT)详解
    本文详细介绍了深度学习中的Vision Transformer (ViT)方法。首先介绍了相关工作和ViT的基本原理,包括图像块嵌入、可学习的嵌入、位置嵌入和Transformer编码器等。接着讨论了ViT的张量维度变化、归纳偏置与混合架构、微调及更高分辨率等方面。最后给出了实验结果和相关代码的链接。本文的研究表明,对于CV任务,直接应用纯Transformer架构于图像块序列是可行的,无需依赖于卷积网络。 ... [详细]
  • IjustinheritedsomewebpageswhichusesMooTools.IneverusedMooTools.NowIneedtoaddsomef ... [详细]
  • 本文介绍了绕过WAF的XSS检测机制的方法,包括确定payload结构、测试和混淆。同时提出了一种构建XSS payload的方法,该payload与安全机制使用的正则表达式不匹配。通过清理用户输入、转义输出、使用文档对象模型(DOM)接收器和源、实施适当的跨域资源共享(CORS)策略和其他安全策略,可以有效阻止XSS漏洞。但是,WAF或自定义过滤器仍然被广泛使用来增加安全性。本文的方法可以绕过这种安全机制,构建与正则表达式不匹配的XSS payload。 ... [详细]
  • 小程序wxs中的时间格式化以及格式化时间和date时间互转
    本文介绍了在小程序wxs中进行时间格式化操作的问题,并提供了解决方法。同时还介绍了格式化时间和date时间的互相转换的方法。 ... [详细]
  • 本文介绍了一个从入门到高手的VB.NET源代码,通过学习这些源代码,可以在21天内成为VB.NET高手。文章提供了下载地址,并提醒读者加入作者的QQ群和收藏作者的博客。 ... [详细]
  • 本文整理了315道Python基础题目及答案,帮助读者检验学习成果。文章介绍了学习Python的途径、Python与其他编程语言的对比、解释型和编译型编程语言的简述、Python解释器的种类和特点、位和字节的关系、以及至少5个PEP8规范。对于想要检验自己学习成果的读者,这些题目将是一个不错的选择。请注意,答案在视频中,本文不提供答案。 ... [详细]
  • loader资源模块加载器webpack资源模块加载webpack内部(内部loader)默认只会处理javascript文件,也就是说它会把打包过程中所有遇到的 ... [详细]
  • 颜色迁移(reinhard VS welsh)
    不要谈什么天分,运气,你需要的是一个截稿日,以及一个不交稿就能打爆你狗头的人,然后你就会被自己的才华吓到。------ ... [详细]
  • 获取时间的函数js代码,js获取时区代码
    本文目录一览:1、js获取服务器时间(动态)2 ... [详细]
author-avatar
你走之后你的美我如何收拾_686
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有