热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

每天三分钟之Pytorch编程5:RNN新对手TCN(4)

本公众号由哥大,清华,复旦,中科大多名研究生共同创建,主要更新最新AI顶会快速解读和Pytorch编程,金融大数据与量化投资。如果你喜欢的话,请关注我们公众号,有学习资源放送,谢谢

本公众号由哥大,清华,复旦,中科大多名研究生共同创建,主要更新最新AI顶会快速解读和Pytorch编程,金融大数据与量化投资。如果你喜欢的话,请关注我们公众号,有学习资源放送,谢谢!

http://weixin.qq.com/r/Bik6IiTElnnprWBB93wU (二维码自动识别)

欢迎加微信号:uft-uft进群交流,记得备注知乎+所在学校或企业。

https://u.wechat.com/MK2C3qQoAkXkuiSqn2DomBc (二维码自动识别)

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》

TCN是指时间卷积网络,一种新型的可以用来解决时间序列预测的算法。在上三次更新中已经做了初步介绍以及数据预处理部分的代码实现。

每天三分钟之Pytorch编程-5:RNN的新对手TCN(1)

每天三分钟之Pytorch编程-5:RNN的新对手TCN(2)

每天三分钟之Pytorch编程-5:RNN的新对手TCN(3)

论文名称:

An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling

作者:Shaojie Bai 1 J. Zico Kolter 2 Vladlen Koltun 3

1

模型展示

为了方便解读代码,先将模型的图片展示如下:

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》

2

代码详解

下面是代码详解部分

实现因果卷积的类

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》

残差模块

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》

时间卷积网络的架构

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》

TCN

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》
《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》

具体代码

# 定义实现因果卷积的类(继承自类nn.Module)

class Chomp1d(nn.Module):

#继承自类nn.Module

def __init__(self, chomp_size):

super(Chomp1d, self).__init__()

#表示对继承自父类的属性进行初始化。

self.chomp_size = chomp_size

def forward(self, x):

return x[:, :, :-self.chomp_size].contiguous()

# tensor.contiguous()会返回有连续内存的相同张量

#有些tensor并不是占用一整块内存,而是由不同的数据块组成

#而tensor的view()操作依赖于内存是整块的,这时只需要执行

#contiguous()这个函数,就能把tensor变成在内存中连续分布的形式。

# 通过增加Padding方式对卷积后的张量做切片而实现因果卷积

# 残差模块,其中有两个一维卷积与恒等映射,具体结构可看图片

class TemporalBlock(nn.Module):

def __init__(self, n_inputs, n_outputs, kernel_size,

stride, dilation, padding, dropout=0.2):

super(TemporalBlock, self).__init__()

self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs,

kernel_size,

stride=stride,

padding=padding,

dilation=dilation))

#定义第一个扩散卷积层,扩散是指dilation=dilation

self.chomp1 = Chomp1d()

# 根据第一个卷积层的输出与padding大小实现因果卷积

self.relu1 = nn.ReLU()

self.dropout1 = nn.Dropout2d(dropout)

# 在先前的输出结果上添加激活函数与dropout完成第一个卷积

self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs,

kernel_size,

stride=stride,

padding=padding,

dilation=dilation))

self.chomp2 = Chomp1d(padding)

# padding保证了输入序列与输出序列的长度相等,

#但卷积前的通道数与卷积后的通道数不一定一样。

self.relu2 = nn.ReLU()

self.dropout2 = nn.Dropout2d(dropout)

#以上四行是与第一个卷积层堆叠了同样结构的第二个卷积层

http://self.net = nn.Sequential(self.conv1, self.chomp1,

self.relu1, self.dropout1,

self.conv2, self.chomp2,

self.relu2, self.dropout2)

# 将卷积模块的所有组建通过Sequential方法依次堆叠在一起

#具体来说的话网络结构是一层一层的叠加起来的,nn库里有一个类型:

#叫做Sequential序列,它是一个容器类,可以在里面添加一些基本的模块。

self.downsample=nn.Conv1d(n_inputs,n_outputs,1)if n_inputs!=n_outputs else None

self.relu = nn.ReLU()

self.init_weights()

#正如先前提到的卷积前与卷积后的通道数不一定相同

#所以如果通道数不一样,那么需要对输入x做一个逐元素的一维卷积

#以使得它的维度与前面两个卷积相等。

def init_weights(self):

self.conv1.weight.data.normal_(0, 0.01)

self.conv2.weight.data.normal_(0, 0.01)

#初始化方法是从均值为0,标准差为0.01的正态分布采样

if self.downsample is not None:

self.downsample.weight.data.normal_(0, 0.01)

def forward(self, x):

out = http://self.net(x)#输入的逐元素卷积与relu激活函数

res = x if self.downsample is None else self.downsample(x)

#残差模块

return self.relu(out + res)

# 定义时间卷积网络的架构

class TemporalConvNet(nn.Module):

def __init__(self, num_inputs, num_channels,

kernel_size=2, dropout=0.2):

super(TemporalConvNet, self).__init__()

layers = []

#num_channels为各层卷积运算的输出通道数或卷积核数量

#num_channels的长度即需要执行的卷积层数量

num_levels = len(num_channels)

# 扩张系数若能随着网络层级的增加而成指数增加,

#则可以增大感受野并不丢弃任何输入序列的元素

for i in range(num_levels):

dilation_size = 2 ** i

#dilation_size根据层级数指数增加

in_channels = num_inputs if i == 0 else num_channels[i – 1]

out_channels = num_channels[i]

#从num_channels中抽取每一个残差模块的输入通道数与输出通道数

layers += [TemporalBlock(in_channels, out_channels,

kernel_size, stride=1,

dilation=dilation_size,

padding=(kernel_size – 1) * dilation_size,

dropout=dropout)]

# 将所有残差模块堆叠起来组成一个深度卷积网络

self.network = nn.Sequential(*layers)

def forward(self, x):

return self.network(x)

class TCN(nn.Module):

def __init__(self, input_size, output_size, num_channels,

kernel_size=2, dropout=0.3, emb_dropout=0.1,

tied_weights=False):

super(TCN, self).__init__()

self.encoder = nn.Embedding(output_size, input_size)

#将one-hot encoding 部分送入编码器作为一个批量的词嵌入向量

#output_size为词汇量,input_size是词向量的长度

self.tcn = TemporalConvNet(input_size, num_channels,

kernel_size, dropout=dropout)

#时间卷积网络的架构

self.decoder = nn.Linear(num_channels[-1], output_size)

# 定义最后线性变换的维度,即最后一个卷积层的通道数到所有词汇的映射

if tied_weights:

if num_channels[-1] != input_size:

raise ValueError(‘When using the tied flag, nhid must be equal to emsize’)

self.decoder.weight = self.encoder.weight

print(“Weight tied”)

#是否共享编码器与解码器的权重,默认值为共享

#共享时需要保持隐藏单元数等于词嵌入的长度

#此时将预测的向量认为是词嵌入向量

self.drop = nn.Dropout(emb_dropout)

self.emb_dropout = emb_dropout

#对输入词嵌入进行dropout

self.init_weights()

def init_weights(self):

self.encoder.weight.data.normal_(0, 0.01)

self.decoder.bias.data.fill_(0)

self.decoder.weight.data.normal_(0, 0.01)

#进行初始化

def forward(self, input):

“””Input ought to have dimension (N, C_in, L_in),

where L_in is the seq_len; here the input is (N, L, C)”””

emb = self.drop(self.encoder(input))#编码并进行dropout

y = self.tcn(emb.transpose(1, 2)).transpose(1, 2)

#输入到网络进行推断

y = self.decoder(y)

#将推断结果解码为词

return y.contiguous()

往期文章精选

AI顶会论文快速解读|将金融领域中的CVaR作为目标策略来定制对话模型

AI顶会论文快速解读|上下文对比特征与门控多尺度聚合用于场景分割

AI顶会论文详细解读|对抗训练解决开放式生成式对话

AI顶会论文详细解读|深度强化学习之基于对话交互的学习对话

AI顶会论文快速解读|控制具体化程度的闲聊对话系统

精美PPT快速熟悉RNN与LSTM,附tensorflow教程代码(回复PPT有资源放送)

每天三分钟之Pytorch编程-1:为何选择你?

每天三分钟之Pytorch编程-2:一起来搭积木吧

每天三分钟之Pytorch编程-3:没事就来炼丹吧

每天三分钟之Pytorch编程-4:来搭建个翻译系统吧(1)

每天三分钟之Pytorch编程-4:来搭建个翻译系统吧(2)

每天三分钟之Pytorch编程-4:来搭建个翻译系统吧(3)

关注本公众号

http://weixin.qq.com/r/Bik6IiTElnnprWBB93wU (二维码自动识别)

《每天三分钟之Pytorch编程-5:RNN新对手TCN(4)》


推荐阅读
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 判断编码是否可立即解码的程序及电话号码一致性判断程序
    本文介绍了两个编程题目,一个是判断编码是否可立即解码的程序,另一个是判断电话号码一致性的程序。对于第一个题目,给出一组二进制编码,判断是否存在一个编码是另一个编码的前缀,如果不存在则称为可立即解码的编码。对于第二个题目,给出一些电话号码,判断是否存在一个号码是另一个号码的前缀,如果不存在则说明这些号码是一致的。两个题目的解法类似,都使用了树的数据结构来实现。 ... [详细]
  • ①页面初始化----------收到客户端的请求,产生相应页面的Page对象,通过Page_Init事件进行page对象及其控件的初始化.②加载视图状态-------ViewSta ... [详细]
  • 【论文】ICLR 2020 九篇满分论文!!!
    点击上方,选择星标或置顶,每天给你送干货!阅读大概需要11分钟跟随小博主,每天进步一丢丢来自:深度学习技术前沿 ... [详细]
  • 本文介绍了使用kotlin实现动画效果的方法,包括上下移动、放大缩小、旋转等功能。通过代码示例演示了如何使用ObjectAnimator和AnimatorSet来实现动画效果,并提供了实现抖动效果的代码。同时还介绍了如何使用translationY和translationX来实现上下和左右移动的效果。最后还提供了一个anim_small.xml文件的代码示例,可以用来实现放大缩小的效果。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了在Linux下安装Perl的步骤,并提供了一个简单的Perl程序示例。同时,还展示了运行该程序的结果。 ... [详细]
  • 也就是|小窗_卷积的特征提取与参数计算
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了卷积的特征提取与参数计算相关的知识,希望对你有一定的参考价值。Dense和Conv2D根本区别在于,Den ... [详细]
  • 本文介绍了一些Java开发项目管理工具及其配置教程,包括团队协同工具worktil,版本管理工具GitLab,自动化构建工具Jenkins,项目管理工具Maven和Maven私服Nexus,以及Mybatis的安装和代码自动生成工具。提供了相关链接供读者参考。 ... [详细]
  • Hibernate延迟加载深入分析-集合属性的延迟加载策略
    本文深入分析了Hibernate延迟加载的机制,特别是集合属性的延迟加载策略。通过延迟加载,可以降低系统的内存开销,提高Hibernate的运行性能。对于集合属性,推荐使用延迟加载策略,即在系统需要使用集合属性时才从数据库装载关联的数据,避免一次加载所有集合属性导致性能下降。 ... [详细]
  • 第8章 使用外部和内部链接
    8.1使用web地址LearnAboutafricanelephants. ... [详细]
  • angular.element使用方法及总结
    2019独角兽企业重金招聘Python工程师标准在线查询:http:each.sinaapp.comangularapielement.html使用方法 ... [详细]
  • Vue基础一、什么是Vue1.1概念Vue(读音vjuː,类似于view)是一套用于构建用户界面的渐进式JavaScript框架,与其它大型框架不 ... [详细]
  • 获取时间的函数js代码,js获取时区代码
    本文目录一览:1、js获取服务器时间(动态)2 ... [详细]
author-avatar
侯faulds_534
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有