Transformer主要由两个部分组成,即Encoder和Decoder,Encoder由n个Self.Attention和一个feedforward层组成,那么接下来是Self.Attention层的代码
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention,self).__init__() # 继承父类 nn.Module 的 __init__ 方法
self.embed_size = embed_size # 嵌入向量的维度
self.heads = heads # 多头注意力机制的头数
self.head_dim = embed_size // heads # 每个头的维度
# 确保头的维度乘以头数等于嵌入向量的维度
assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
# 初始化线性层,用于计算值、键、查询
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
# 初始化线性层,用于将多头注意力机制的输出连接起来
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
在__init__的时候,我们要定义传入的heads数量,即多头注意力个数,随后传入embed_size用来标记初始词向量表示空间的大小。最后利用一个头的维度空间就是用,head_dim = embed_size // heads来表示。因为传入self.attention的层的qkv需要经过线性变换,这里初始化的时候给出了可学习的参数矩阵。self.fc_out表示将n个头的矩阵拼接起来,最后经过一个线性变换得到一个大小为embed_size的数据。
def forward(self, values, keys, query, mask):
N = query.shape[0] # 批次大小
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# 将嵌入向量拆分为多个头
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# 计算能量(注意力得分),通过计算查询和键的点积得到
energy = torch.einsum("nqhd,nkhd->nhqk", [queries,keys])
# 如果存在掩码,则在能量上应用掩码
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# 计算注意力权重,通过将能量除以嵌入向量的维度的平方根,并应用 softmax 函数
attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1)
# 计算输出,通过计算注意力权重和值的点积,然后重新调整形状
out = torch.einsum("nhql,nlhd->nqhd",[attention, values]).reshape(
N, query_len, self.heads*self.head_dim
)
# 通过线性层,将多头注意力机制的输出连接起来
out = self.fc_out(out)
return out # 返回输出
具体来说,在初始化SelfAttention后,我们会传入values,keys,query三个数据(经过position embedding和word embedding)还有mask。这里要注意values,keys,query三个数据的大小[batch_size,seq_len,embed_size],这里的batch_size表示批量大小,seq_len表示具体句子里面单词的个数,embed_size表示每个单词由多少维度的特征表示。传入的数据要判断是否经过掩码,之后我们将得到的energy缩放,得到attention,然后attention和对应的values进行计划,得到最后的输出。最后注意,由于由有个头,我们输出的时候要将heads个头拼接起来,得到最后一个输出。