热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

无监督关键短语的生成问题博客06model.py的分析

2021SCSDUSC在上一篇博客中,我们分析了model.py中的Decoder类,并对LSTM作了简要的介绍,以实例来说明了构建模型

2021SC@SDUSC

在上一篇博客中,我们分析了model.py中的Decoder类,并对LSTM作了简要的介绍,以实例来说明了构建模型时的各个参数,以及引入了嵌入层。本篇博客我们将从RNN模型的分类入手,分析Seq2Seq模型的框架并分析本篇论文中的Seq2Seq代码。

Encoder-Decoder框架是一个End-to-End学习的算法。简单来说,Seq2Seq(即Sequence to Sequence)以一个Encoder来编码输入的Sequence,再以一个Decoder来输出Sequence。其大致框架是一个序列经过Encoderr得到一个隐状态,再通过这个隐状态使用Decoder得到最终需要的序列。

 图1:Seq2Seq模型示例


一、 RNN的分类

接下来我们将由RNN的分类引出Seq2Seq模型。按照输入和输出的结构,可以将RNN分为:


  • N vs N - RNN
  • N vs 1 - RNN
  • 1 vs N - RNN
  • N vs M - RNN 

 1. N vs N - RNN

 图2:N vs N - RNN图示

这是RNN最基础的结构形式,其特点是输出和输入序列是等长的,但也由于这个限制的存在,其适用范围比较小,可以用于生成等长的诗句。


2. N vs 1 - RNN

  图3:N vs 1 - RNN图示

有时候我们需要处理的问题输入是一个序列,但输出的是一个单独的值而不是序列,我们只需要在隐藏层输出h上进行线性变换就可以了,大部分情况下,为了更好地明确结果,还需要sigmoid或者softmax进行处理,这种结构通常被用在文本分类上。注意,这里的h_4

  图4:1 vs N - RNN图示 

当输入的不是序列而输出的是序列时,我们需要将输入作用在每次输出之上,这种结构可以用于将图片生成文字任务。


4. N vs M - RNN

 图5:N vs M - RNN图示 

这是一种不限制输入输出长度的RNN结构,它由编码器和解码器两部分组成,可以理解为编码器部分是N vs 1 - RNN,解码部分为1 vs N - RNN,这类RNN就被理解为Seq2Seq架构,输入数据首先经过编码器,最终输出一个隐含变量c

该模型接受句子作为输入,使用Encoder来生成向量,再用Decoder来生成目标语句。将上面两个模块结合起来就得到了整体的模型。在序列到序列处理不定长序列的过程中,采用了序列的起始标志和终止标志来“告诉”编码器的编码过程何时开始与结束,也就是间接反映了当前序列的长度信息。这里将结合到我们后面分析的vocabulary构建字典,对于语料库中的所有词,构建一个字典,实现word2idx和idx2word方法,对于语句的开始和结尾,用填充,对于没有出现在字典中的词,用填充,然后转为idx,再用模型处理。


三、代码分析

class Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super().__init__()self.encoder = encoderself.decoder = decoderself.device = deviceassert encoder.hid_dim == decoder.hid_dim, \"Hidden dimensions of encoder and decoder must be equal!"def forward(self, src, trg, teacher_forcing_ratio = 0.5):#src = [src len, batch size]#trg = [trg len, batch size]#teacher_forcing_ratio is probability to use teacher forcing#e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the timebatch_size = trg.shape[1]trg_len = trg.shape[0]trg_vocab_size = self.decoder.output_dim#存储decoder输出的张量outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)#encoder最后的隐藏层context = self.encoder(src)#encoder最后的隐藏层作为decoder初始的隐藏层hidden = context#首先输入decoder的是input = trg[0,:]for t in range(1, trg_len):#插入嵌入的输入标记,之前的隐藏层状态和context#接收输出张量和新的隐藏层状态output, hidden = self.decoder(input, hidden, context)#将预测放在一个张量中,该张量有每个token的预测outputs[t] = output#是否使用teaching_forceteacher_force = random.random()

将源序列x馈入编码器以接收上下文张量,创建输出张量保留所有预测。使用一批令牌作为第一个输入𝑦1,然后,我们在一个循环中解码,将输入令牌𝑦𝑡,先前的隐藏状态𝑠𝑡−1和上下文向量插入解码器,接收预测𝑡+ 1以及新的隐藏状态𝑠𝑡。

训练模型:

def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, batch in enumerate(iterator):# 从批处理中后去源句子和目标句子X,Ysrc = batch.srctrg = batch.trg# 将最后一批计算出的梯度归零optimizer.zero_grad()# 将源和目标放入模型中输出下一个youtput = model(src, trg)#trg = [trg len, batch size]#output = [trg len, batch size, output dim]output_dim = output.shape[-1]# 由于损失函数是仅适用于具有1d目标的2d输入,因此我们用view将其展平输入# 将输出张量和目标张量的第一列切开output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)#trg = [(trg len - 1) * batch size]#output = [(trg len - 1) * batch size, output dim]loss = criterion(output, trg)# 用此函数计算梯度loss.backward()# clip the gradients 防止其梯度爆炸torch.nn.utils.clip_grad_norm_(model.parameters(), clip)# 执行优化程序步骤来更新模型的参数optimizer.step()# 损失值求和epoch_loss += loss.item()# 返回所有批次的平均损失 return epoch_loss / len(iterator)

从批处理中后去源句子和目标句子X,Y,将最后一批计算出的梯度归零,将源和目标放入模型中输出下一个y,计算梯度同时要防止梯度爆炸,损失值求和最后返回所有批次的平均损失。

评估部分代码:

def evaluate(model, iterator, criterion):# 设置为评估模式,关闭Dropout(弱使用批处理规范化,则也关闭)model.eval()epoch_loss = 0with torch.no_grad():# 确保该模块内不计算梯度for i, batch in enumerate(iterator):src = batch.srctrg = batch.trg# 必须确保关闭teacher_forcing 参数output = model(src, trg, 0) #turn off teacher forcing#trg = [trg len, batch size]#output = [trg len, batch size, output dim]output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)#trg = [(trg len - 1) * batch size]#output = [(trg len - 1) * batch size, output dim]loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)

计算运行时间:

def epoch_time(start_time, end_time):elapsed_time = end_time - start_timeelapsed_mins = int(elapsed_time / 60)elapsed_secs = int(elapsed_time - (elapsed_mins * 60))return elapsed_mins, elapsed_secsN_EPOCHS = 10
CLIP = 1best_valid_loss = float('inf')for epoch in range(N_EPOCHS):start_time = time.time()train_loss = train(model, train_iterator, optimizer, criterion, CLIP)valid_loss = evaluate(model, valid_iterator, criterion)end_time = time.time()epoch_mins, epoch_secs = epoch_time(start_time, end_time)if valid_loss model.load_state_dict(torch.load('tut1-model.pt'))test_loss = evaluate(model, test_iterator, criterion)print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')


四、Attention模型

传统的Encoder-Decoder是有很多弊端的,后来又提出了Attention模型,这种模型在产生输出的时候,还会产生一个“注意力范围”表示接下来输出的时候要重点关注输入序列中的哪些部分,然后根据关注的区域来产生下一个输出,如此往复。模型的大概示意图如下所示。

目前的attention模型可大致分为两类:


  • 聚焦式(focus)注意力:自上而下的有意识的注意力,主动注意——是指有预定目的、依赖任务的、主动有意识地聚焦于某一对象的注意力;
  • 显著性(saliency-based)注意力:自下而上的有意识的注意力,被动注意——基于显著性的注意力是由外界刺激驱动的注意,不需要主动干预,也和任务无关;


推荐阅读
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • CSS3选择器的使用方法详解,提高Web开发效率和精准度
    本文详细介绍了CSS3新增的选择器方法,包括属性选择器的使用。通过CSS3选择器,可以提高Web开发的效率和精准度,使得查找元素更加方便和快捷。同时,本文还对属性选择器的各种用法进行了详细解释,并给出了相应的代码示例。通过学习本文,读者可以更好地掌握CSS3选择器的使用方法,提升自己的Web开发能力。 ... [详细]
  • 本文由编程笔记#小编为大家整理,主要介绍了logistic回归(线性和非线性)相关的知识,包括线性logistic回归的代码和数据集的分布情况。希望对你有一定的参考价值。 ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • sklearn数据集库中的常用数据集类型介绍
    本文介绍了sklearn数据集库中常用的数据集类型,包括玩具数据集和样本生成器。其中详细介绍了波士顿房价数据集,包含了波士顿506处房屋的13种不同特征以及房屋价格,适用于回归任务。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • [译]技术公司十年经验的职场生涯回顾
    本文是一位在技术公司工作十年的职场人士对自己职业生涯的总结回顾。她的职业规划与众不同,令人深思又有趣。其中涉及到的内容有机器学习、创新创业以及引用了女性主义者在TED演讲中的部分讲义。文章表达了对职业生涯的愿望和希望,认为人类有能力不断改善自己。 ... [详细]
  • 《数据结构》学习笔记3——串匹配算法性能评估
    本文主要讨论串匹配算法的性能评估,包括模式匹配、字符种类数量、算法复杂度等内容。通过借助C++中的头文件和库,可以实现对串的匹配操作。其中蛮力算法的复杂度为O(m*n),通过随机取出长度为m的子串作为模式P,在文本T中进行匹配,统计平均复杂度。对于成功和失败的匹配分别进行测试,分析其平均复杂度。详情请参考相关学习资源。 ... [详细]
  • 拥抱Android Design Support Library新变化(导航视图、悬浮ActionBar)
    转载请注明明桑AndroidAndroid5.0Loollipop作为Android最重要的版本之一,为我们带来了全新的界面风格和设计语言。看起来很受欢迎࿰ ... [详细]
  • 自动轮播,反转播放的ViewPagerAdapter的使用方法和效果展示
    本文介绍了如何使用自动轮播、反转播放的ViewPagerAdapter,并展示了其效果。该ViewPagerAdapter支持无限循环、触摸暂停、切换缩放等功能。同时提供了使用GIF.gif的示例和github地址。通过LoopFragmentPagerAdapter类的getActualCount、getActualItem和getActualPagerTitle方法可以实现自定义的循环效果和标题展示。 ... [详细]
  • 本文详细介绍了Java中vector的使用方法和相关知识,包括vector类的功能、构造方法和使用注意事项。通过使用vector类,可以方便地实现动态数组的功能,并且可以随意插入不同类型的对象,进行查找、插入和删除操作。这篇文章对于需要频繁进行查找、插入和删除操作的情况下,使用vector类是一个很好的选择。 ... [详细]
author-avatar
伤不起饼子_132
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有