作者:lzmhezy198344 | 来源:互联网 | 2022-12-13 20:36
我正在关注Pytorch seq2seq教程,其使用torch.bmm
方法如下:
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
encoder_outputs.unsqueeze(0))
我了解为什么我们需要将注意力权重和编码器输出相乘。
我不太明白的是我们bmm
在这里需要方法的原因。
torch.bmm
文件说
对批处理1和批处理2中存储的矩阵执行批处理矩阵矩阵乘积。
batch1和batch2必须是每个都包含相同数量矩阵的3-D张量。
如果batch1是(b×n×m)张量,batch2是(b×m×p)张量,out将是(b×n×p)张量。
1> Wasi Ahmad..:
在seq2seq模型中,编码器对以小批量形式给出的输入序列进行编码。假设输入为B x S x d
,其中B是批处理大小,S是最大序列长度,d是单词嵌入维数。然后,编码器的输出为:B x S x h
其中h是编码器(RNN)的隐藏状态大小。
现在,在解码时(在训练过程中)
,输入序列一次被赋予一个,因此输入为B x 1 x d
,解码器产生shape的张量B x 1 x h
。现在要计算上下文向量,我们需要将解码器的隐藏状态与编码器的编码状态进行比较。
因此,考虑您有两个形状为T1 = B x S x h
和的张量T2 = B x 1 x h
。因此,如果可以按如下所示进行批矩阵乘法。
out = torch.bmm(T1, T2.transpose(1, 2))
本质上,您是将一个形状B x S x h
的张量与一个形状的张量相乘B x h x 1
,这将导致B x S x 1
每个批次的注意权重。
这里,注意力权重B x S x 1
表示解码器的当前隐藏状态与编码器的所有隐藏状态之间的相似度得分。现在,您可以B x S x h
通过首先换位来吸引注意力权重,使其与编码器的隐藏状态相乘,这将导致shape的张量B x h x 1
。而且,如果在dim = 2时执行挤压,将得到形状的张量,B x h
这是您的上下文向量。
此上下文向量(B x h
)通常与解码器的隐藏状态(B x 1 x h
,挤压dim = 1)相连,以预测下一个标记。