↑↑↑关注后"星标"Datawhale
每日干货 & 每月组队学习,不错过
Datawhale干货
作者:安晟&闫永强,Datawhale成员
本篇正文部分约10000字,分模块解读并实践了Transformer,建议收藏阅读。
2017年谷歌在一篇名为《Attention Is All You Need》的论文中,提出了一个基于attention(自注意力机制)结构来处理序列相关的问题的模型,名为Transformer。
Transformer在很多不同nlp任务中获得了成功,例如:文本分类、机器翻译、阅读理解等。在解决这类问题时,Transformer模型摒弃了固有的定式,并没有用任何CNN或者RNN的结构,而是使用了Attention注意力机制,自动捕捉输入序列不同位置处的相对关联,善于处理较长文本,并且该模型可以高度并行地工作,训练速度很快。
本文将按照Transformer的模块进行讲解,每个模块配合代码+注释+讲解
来介绍,最后会有一个玩具级别的序列预测任务进行实战。
通过本文,希望可以帮助大家,初探Transformer的原理和用法,下面直接进入正式内容:
1 模型结构概览
如下是Transformer的两个结构示意图:
上图是从一篇英文博客中截取的Transformer的结构简图,下图是原论文中给出的结构简图,更细粒度一些,可以结合着来看。
模型大致分为Encoder
(编码器)和Decoder
(解码器)两个部分,分别对应上图中的左右两部分。
其中编码器由N个相同的层堆叠在一起(我们后面的实验取N=6),每一层又有两个子层。
第一个子层是一个Multi-Head Attention
(多头的自注意机制),第二个子层是一个简单的Feed Forward
(全连接前馈网络)。两个子层都添加了一个残差连接+layer normalization的操作。
模型的解码器同样是堆叠了N个相同的层,不过和编码器中每层的结构稍有不同。对于解码器的每一层,除了编码器中的两个子层Multi-Head Attention
和Feed Forward
,解码器还包含一个子层Masked Multi-Head Attention
,如图中所示每个子层同样也用了residual以及layer normalization。
模型的输入由Input Embedding
和Positional Encoding
(位置编码)两部分组合而成,模型的输出由Decoder的输出简单的经过softmax得到。
结合上图,我们对Transformer模型的结构做了个大致的梳理,只需要先有个初步的了解,下面对提及的每个模块进行详细介绍。
2 模型输入
首先我们来看模型的输入是什么样的,先明确模型输入,后面的模块理解才会更直观。
输入部分包含两个模块,Embedding
和 Positional Encoding
。
2.1 Embedding层
Embedding层的作用是将某种格式的输入数据,例如文本,转变为模型可以处理的向量表示,来描述原始数据所包含的信息。
Embedding
层输出的可以理解为当前时间步的特征,如果是文本任务,这里就可以是Word Embedding
,如果是其他任务,就可以是任何合理方法所提取的特征。
构建Embedding层的代码很简单,核心是借助torch提供的nn.Embedding
,如下:
class Embeddings(nn.Module):
def __init__(self, d_model, vocab):
"""
类的初始化函数
d_model:指词嵌入的维度
vocab:指词表的大小
"""
super(Embeddings, self).__init__()
#之后就是调用nn中的预定义层Embedding,获得一个词嵌入对象self.lut
self.lut = nn.Embedding(vocab, d_model)
#最后就是将d_model传入类中
self.d_model =d_model
def forward(self, x):
"""
Embedding层的前向传播逻辑
参数x:这里代表输入给模型的单词文本通过词表映射后的one-hot向量
将x传给self.lut并与根号下self.d_model相乘作为结果返回
"""
embedds = self.lut(x)
return embedds * math.sqrt(self.d_model)
2.2 位置编码:
Positional Encodding
位置编码的作用是为模型提供当前时间步的前后出现顺序的信息。因为Transformer不像RNN那样的循环结构有前后不同时间步输入间天然的先后顺序,所有的时间步是同时输入,并行推理的,因此在时间步的特征中融合进位置编码的信息是合理的。
位置编码可以有很多选择,可以是固定的,也可以设置成可学习的参数。
这里,我们使用固定的位置编码。具体地,使用不同频率的sin和cos函数来进行位置编码,如下所示:
其中pos代表时间步的下标索引,向量
也就是第pos个时间步的位置编码,编码长度同Embedding
层,这里我们设置的是512。上面有两个公式,代表着位置编码向量中的元素,奇数位置和偶数位置使用两个不同的公式。
思考:为什么上面的公式可以作为位置编码?
我的理解:在上面公式的定义下,时间步p和时间步p+k的位置编码的内积,即