Bootstrap

Transformer训练过程、推理过程详解

系列文章目录



前言

 Transformer是一种很重要且强大的模型,如今网上介绍Transformer的文章很多,但是大都是只是介绍模型的结构,对于训练过程、推理过程模型的输入和输出都没有进行介绍,使人看的一头雾水,因为本文就Transformer的模型的训练和推理过程进行详细的介绍。

一、模型结构

 下面的Transformer的模型结构图,其分为编码器和解码器两部分。模型的结构详解网上有很多,在这里就不和大家赘述了。关于位置编码、自注意力机制和掩码部分的详细介绍可以参考这两篇博文。
   Transformer输入Embedding及位置编码详解
   Transformer多头自注意力及掩码机制详解
在这里插入图片描述

在这里插入图片描述

二、训练过程

 比如我们想做一个中译英的机器翻译任务,现在我们已经准备好了数据集,例如{‘chinese’:我爱吃梨,‘english’: i love eating pears} ,那么模型的输入和输出以及标签分别是什么呢?
首先我们会先根据数据集创建中文词表和英文词表,假如数据集中只有{‘chinese’:我爱吃梨,‘english’: i love eating pears} 这一条样本,那么中文词表为{‘我’:0,‘爱’:1,‘吃’:2,‘梨’:3},英文词表为{‘PAD’:0,‘BOS’:1,‘EOS’:2,‘i’:3,‘love’:4,‘eating’:5,‘pears’:6}。为什么英文词表中还多了一些特殊字符呢?首先’PAD’字符是填充的意思,在进行输入的时候为了保持一个batch中的句子长度一样,会将句子填充到一样的长度,所以如果句子的长度过短,会在其后面增加填充符。’BOS‘字符是开始符号,解码器的输入中的第一个字符便是’BOS’字符;’EOS‘字符为结束字符。
创建好此表后,我们根据词表将中英文输入映射成数字。如下所示:

encoder输入:
我爱吃梨---------------------------------------------------[0,1,2,3]
decoder输入:
‘BOS’ i love eating pears ------------------------------[1,3,4,5,6]
标签:
i love eating pears ‘EOS’ ------------------------------[3,4,5,6,2]
decoder输出:(模型还没有训练好)
i love pears pears ‘EOS’------------------------------------[3,4,6,6,2]

 可以看到,编码器的输入是一整个句子,解码器的输入和输出也是完整的一个句子,最后根据解码器输出的句子去和标签去算loss。因此transformer在训练的时候是并行的。
 为什么说是并行的呢?例如解码器输入是’‘BOS’ i love eating pears‘,输出是’i love pears pears ‘EOS’‘,因此可以看出,输入中的’BOS’对应解码器预测输出’i’, 输入中的‘i’对应解码器的预测输出中的’love’。与自回归方式不同(把解码器的输出当作输入去预测,一个字一个字的输出),Transformer我们可以想象成给解码器扔进去一句话,然后解码器同时也输出了一句话。因此Transformer的训练速度相比与传统的RNN也快了很多。同时也因为是并行的,所以对内存要求也大大增加。
 但是很明显,如果训练的时候给解码器输入一个完整的句子,例如输入’‘BOS’ i love eating pears‘,训练模型输出i love pears pears ‘EOS’,这对于模型来说没什么好学的呀,我直接把你输入的句子去掉‘BOS’,然后再结尾加上‘EOS’不就可以了。这不就相当于平时开卷做作业抄答案回回满分,一到真正的闭卷考试就考不好了。因此,为了让模型再预测当前单词的输出不提前看到后面的”答案“。我们需要加上因果mask,把当前时刻的词后面的语句给掩盖掉,从而让模型不是通过去抄答案获得最终的输出。

三、推理过程

 在训练的时候,模型解码器的输入和输出都是一整个句子,通过对输入的句子加入因果mask实现并行训练。那如果已经训练好了一个Transformer模型,使用它做机器翻译推理的时候,我们只有编码器的输入,那么解码器的输入输出是什么呢?

encoder输入:
我爱吃梨 (Transformer编码器只需要编码一次就可以)

解码器推理次数decoder输入:decoder输出:
1‘BOS’i
2‘BOS’ ii love
3‘BOS’ i lovei love eatting
4‘BOS’ i love eattingi love eatting pears
5‘BOS’ i love eatting pearsi love eatting pears ‘EOS’

 通过上面的例子可以看出,在推理过程中transformer的输出采用的是自回归的方式,即将transformer解码器的输出当作输入然后重新去做推理,然后得到下一个输出,直到输出’EOS’。在推理时,会先给解码器一个’BOS’的字作为第一个输入,然后开始自回归输出。这也就是为什么训练时要在输入中加入’BOS’。
注意:由于Transformer结构的特点,其解码器编码器的输入维度和输出维度相同,所以在推理的时候,解码器输入几个字,其就输出几个字,且输出的字都是你提前定义的词表中的字。

;