热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

VisionTransformer原理及代码实战

VisionTransformer原理及代码实战背景

Vision Transformer原理及代码实战

背景

论文地址:https://arxiv.org/pdf/2010.11929.pdf

代码参考:https://github.com/BR-IDL/PaddleViT

在NLP领域,Transformer深度学习技术已经"统治"了该领域;

在CV领域,从2020年底开始,Vision Transformer(ViT)成为该方向的研究热点;基于Transformer的模型在多个视觉任务中已经超越CNN模型达到SOTA性能的程度;

Transformer概念引入

Transform一开始是出现在NLP领域中,下面看一个翻译的实际应用:

在这里插入图片描述

主要实现步骤为:

输入文本 —— 分词 —— Transformer模型 —— 输出结果

实际上Encoders和Decoders代表的是多个的组成,类似于卷积网络的堆叠;

NLP中单独的Encoder和Decoder的具体实现如下:

在这里插入图片描述

其中的MSA和FFN结构在后续的代码实战中会进行讲解;

Vision Transformer引入

受到NLP领域中Transforms成功应用的启发,ViT算法中尝试将标准的Transformer结构直接应用图图像中,实现流程如下:

1、将整个图像拆分成小图像块;

2、将小图像块映射成线性嵌入序列;

3、将线性嵌入序列传入网络中实现任务;

在这里插入图片描述

其中最重要的步骤为Patch Embedding和Encoder,暂时没用到Decoder;

在中规模和大规模的数据集下,作者验证得到以下结论:

1、Transformer对比CNN结构,缺少一定的平移不变性和局部感知性,因此在数据量不够大时,很难达到CNN的同等效果;也就是说在中规模数据集下效果会比CNN的低上几个百分点;

2、当具有大量训练样本时,可使用大规模数据集训练后,再使用迁移学习的方式应用到其他数据集上,此时Transformer可以超越或达到SOTA的水平;

Patch Embedding原理

Patch Embedding又称为图像分块嵌入,Transformer结构中,需要输入的是一个二维矩阵(S,D),其中S是sequence的长度,D是sequcence中每个向量的维度,因此需要将三维的图像矩阵转换为二维的矩阵;

ViT中具体的实现方式为,将HWC的图像变成一个S x (P²*C)的序列;其中P代表图像块的边长,C代表通道数,N则表示图像块的个数(WH/P²);由于最终需要的向量维度为D,需要再做一个Embedding的操作,对(P² * C)的图像块做一个线性变化压缩为D即可;

Embedding的定义:高维空间向低维空间的映射;

在这里插入图片描述

上面的Patch Embedding也可以通过卷积滑窗来实现(也就是卷积实现)

Attention注意力机制原理

Attention在论文中是这么解释的:在单个序列中使用不同位置的注意力用于实现该序列的表征方法;

最重要的就是提出了query - key - value思想,当时的该模型聚焦的任务主要是question answering,先用输入的问题query检索key-value memories,找到和问题相似的memory的key,计算相关性分数,然后对value embedding进行加权求和,得到一个输出向量,慢慢就衍生了Attention中的qkv;

QKV是输入的X乘上Wq, Wk, Wv三个矩阵得到的;

Self Attention的计算图:

在这里插入图片描述

结构逻辑图:

在这里插入图片描述

代码实现

1、首先实现一下Patch Embedding结构;

class PatchEmbedding(nn.Layer):
def __init__(self, image_size, patch_size, in_channels, embed_dim, dropout=0.):
super().__init__()
# embedding本质是一个卷积操作
self.patch_embedding = nn.Conv2D(in_channels,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias_attr=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# x 原来为[n, c, h, w]
x = self.patch_embedding(x) # 经过卷积操作后:[n, c', h', w'],c'是我们所需要的维度
x = x.flatten(2) # 将2、3维度合并:[n, c', h'*w']
x = x.transpose([0, 2, 1]) # 维度转换:[n, h'*w', c']
x = self.dropout(x)
return x

2、实现一个MLP的结构

MLP实际上就是两层全连接,并且经过MLP后维度不发生改变;

class Mlp(nn.Layer):
def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.):
super().__init__()
# 两层全连接层
self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))
self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
# GELU的激活函数
self.act = nn.GELU()
# dropout层
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x

3、实现一个Encoder层

class EncoderLayer(nn.Layer):
def __init__(self, embed_dim):
super().__init__()
# 做特征归一化操作
self.attn_norm = nn.LayerNorm(embed_dim)
# Attention层在之后进行实现
self.attn = Attention()
self.mlp_norm = nn.LayerNorm(embed_dim)
# 之前实现的MLP结构
self.mlp = Mlp(embed_dim)
def forward(self, x):
# 这里也有用到残杀结构
h = x
x = self.attn_norm(x)
x = self.attn(x)
x = x + h # 维度不变,可直接相加
h = x
x = self.mlp_norm(x)
x = self.mlp(x)
x = x + h
return x

4、Attention代码实现

class Attention(nn.Layer):
def __init__(self, embed_dim, num_heads,
qkv_bias=False, qk_scale=None, dropout=0., attention_dropout=0.):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = int(embed_dim / num_heads)
self.all_head_dim = self.head_dim * num_heads
self.qkv = nn.Linear(embed_dim,
self.all_head_dim * 3,
bias_attr=False if qkv_bias is False else None)
self.scale = self.head_dim ** -0.5 if qk_scale is None else qk_scale
self.dropout = nn.Dropout(dropout)
self.attention_dropout = nn.Dropout(attention_dropout)
self.proj = nn.Linear(self.all_head_dim, embed_dim)
self.softmax = nn.Softmax(-1)
def transpose_multi_head(self, x):
# x:[n, num_patches, all_head_dim]
new_shape = x.shape[:-1] + [self.num_heads, self.head_dim]
x = x.reshape(new_shape)
# x:[n, num_patches, num_heads, head_dim]
x = x.transpose([0, 2, 1, 3])
# x:[n, num_heads, num_patches, head_dim]
return x
def forward(self, x):
B, N, _ = x.shape
# x: [n, num_patches, embed_dim]
qkv = self.qkv(x).chunk(3, -1)
# qkv: [n, num_patches, all_head_dim] * 3
q, k, v = map(self.transpose_multi_head, qkv)
# q, k, v:[n, num_heads, num_patches, head_dim]
attn = paddle.matmul(q, k, transpose_y=True)
attn = self.scale * attn
attn = self.softmax(attn)
attn_weights = attn
attn = self.attention_dropout(attn)
# attn: [n, num_heads, num_patches, num_patches]

out = paddle.matmul(attn, v)
# out: [n, num_heads, num_patches, head_dim]
out = out.transpose([0, 2, 1, 3])
# out: [n, num_patches, num_heads, head_dim]
out = out.reshape([B, N, -1])
out = self.proj(out)
out = self.dropout(out)
return out, attn_weights

由于当前对Attention理解还不够透彻,先把代码粘贴在这,便于之后回顾;

5、实现ViT结构,将之前实现的结构串联到一起

class ViT(nn.Layer):
def __init__(self):
super().__init__()
# 定义Patch Embedding结构
self.patch_embed = PatchEmbedding(224, 7, 3, 16)
# 定义Encoder层
layer_list = [EncoderLayer(16) for i in range(5)]
self.encoders = nn.LayerList(layer_list)
# 定义全连接层实现分类
self.head = nn.Linear(16, 10)
self.avgpool = nn.AdaptiveAvgPool1D(1)
self.norm = nn.LayerNorm(16)
def forward(self, x):
# 第一步经过Patch Embedding(图像分块)
x = self.patch_embed(x) # [n, h*w, c]: 4, 1024, 16
# 第二步进入Transformer层,也就是五层Encoder
for encoder in self.encoders:
x = encoder(x)
x = self.norm(x)
# 进行维度转换
x = x.transpose([0, 2, 1])
# 将所有batch合并起来
x = self.avgpool(x)
x = x.flatten(1)
# 进行分类,输出对应类别的向量
x = self.head(x)
return x
# 用一个主程序进行验证
if __name__ == "__main__":
t = paddle.randn([4, 3, 224, 224])
model = ViT()
out = model(t)
print(out.shape) # 输出[4, 10]

总结

在ViT中我们运用的是LN的标准化处理,而对比BN有什么区别呢,可以参考下面这篇文章:

参考文章:https://www.cnblogs.com/gczr/p/12597344.html

Paddle中还有一个小技巧,就是用paddle.summary可以打印模型的数据流:

paddle.summary(vit, (4, 3, 224, 224)) # must be tuple

打印结果如下图所示:

在这里插入图片描述

可以看出每一层的名称,对应的input和output,以及所占用的参数数量;

最后,ViT属于当前比较前沿的技术点,往往对大型数据集有比较好的效果,实际在工作中接触到的数据集没有那么大,加入ViT的结构可能没有很好的效果,反而会影响速度(毕竟有多个Linner层),了解前沿的技术还是有助于我们对网络的选择以及修改的,多学没有坏处!

在这里插入图片描述


推荐阅读
  • 本文介绍了Android中的assets目录和raw目录的共同点和区别,包括获取资源的方法、目录结构的限制以及列出资源的能力。同时,还解释了raw目录中资源文件生成的ID,并说明了这些目录的使用方法。 ... [详细]
  • 本文介绍了响应式页面的概念和实现方式,包括针对不同终端制作特定页面和制作一个页面适应不同终端的显示。分析了两种实现方式的优缺点,提出了选择方案的建议。同时,对于响应式页面的需求和背景进行了讨论,解释了为什么需要响应式页面。 ... [详细]
  • 本文介绍了H5游戏性能优化和调试技巧,包括从问题表象出发进行优化、排除外部问题导致的卡顿、帧率设定、减少drawcall的方法、UI优化和图集渲染等八个理念。对于游戏程序员来说,解决游戏性能问题是一个关键的任务,本文提供了一些有用的参考价值。摘要长度为183字。 ... [详细]
  • 如何自行分析定位SAP BSP错误
    The“BSPtag”Imentionedintheblogtitlemeansforexamplethetagchtmlb:configCelleratorbelowwhichi ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • XML介绍与使用的概述及标签规则
    本文介绍了XML的基本概念和用途,包括XML的可扩展性和标签的自定义特性。同时还详细解释了XML标签的规则,包括标签的尖括号和合法标识符的组成,标签必须成对出现的原则以及特殊标签的使用方法。通过本文的阅读,读者可以对XML的基本知识有一个全面的了解。 ... [详细]
  • 个人学习使用:谨慎参考1Client类importcom.thoughtworks.gauge.Step;importcom.thoughtworks.gauge.T ... [详细]
  • HTML学习02 图像标签的使用和属性
    本文介绍了HTML中图像标签的使用和属性,包括定义图像、定义图像地图、使用源属性和替换文本属性。同时提供了相关实例和注意事项,帮助读者更好地理解和应用图像标签。 ... [详细]
  • [大整数乘法] java代码实现
    本文介绍了使用java代码实现大整数乘法的过程,同时也涉及到大整数加法和大整数减法的计算方法。通过分治算法来提高计算效率,并对算法的时间复杂度进行了研究。详细代码实现请参考文章链接。 ... [详细]
  • 本文介绍了南邮ctf-web的writeup,包括签到题和md5 collision。在CTF比赛和渗透测试中,可以通过查看源代码、代码注释、页面隐藏元素、超链接和HTTP响应头部来寻找flag或提示信息。利用PHP弱类型,可以发现md5('QNKCDZO')='0e830400451993494058024219903391'和md5('240610708')='0e462097431906509019562988736854'。 ... [详细]
  • Redis底层数据结构之压缩列表的介绍及实现原理
    本文介绍了Redis底层数据结构之压缩列表的概念、实现原理以及使用场景。压缩列表是Redis为了节约内存而开发的一种顺序数据结构,由特殊编码的连续内存块组成。文章详细解释了压缩列表的构成和各个属性的含义,以及如何通过指针来计算表尾节点的地址。压缩列表适用于列表键和哈希键中只包含少量小整数值和短字符串的情况。通过使用压缩列表,可以有效减少内存占用,提升Redis的性能。 ... [详细]
  • 第四章高阶函数(参数传递、高阶函数、lambda表达式)(python进阶)的讲解和应用
    本文主要讲解了第四章高阶函数(参数传递、高阶函数、lambda表达式)的相关知识,包括函数参数传递机制和赋值机制、引用传递的概念和应用、默认参数的定义和使用等内容。同时介绍了高阶函数和lambda表达式的概念,并给出了一些实例代码进行演示。对于想要进一步提升python编程能力的读者来说,本文将是一个不错的学习资料。 ... [详细]
  • 网址:https:vue.docschina.orgv2guideforms.html表单input绑定基础用法可以通过使用v-model指令,在 ... [详细]
  • Windows7企业版怎样存储安全新功能详解
    本文介绍了电脑公司发布的GHOST WIN7 SP1 X64 通用特别版 V2019.12,软件大小为5.71 GB,支持简体中文,属于国产软件,免费使用。文章还提到了用户评分和软件分类为Win7系统,运行环境为Windows。同时,文章还介绍了平台检测结果,无插件,通过了360、腾讯、金山和瑞星的检测。此外,文章还提到了本地下载文件大小为5.71 GB,需要先下载高速下载器才能进行高速下载。最后,文章详细解释了Windows7企业版的存储安全新功能。 ... [详细]
author-avatar
王小贱
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有