手写+分析bert
目录
前言
Attention is all you need!
读本文前,建议至少看懂【模型学习之路】手写+分析Transformer-CSDN博客。
毕竟Bert是transformer的变种之一。
架构
embeddings
Bert可以说就是transformer的Encoder,就像训练卷积网络时可以利用现成的网络然后fine tune就投入使用一样,Bert的动机就是训练一种预训练模型,之后根据不同的场景可以做不同的fine tune。
这里我们还是B代表批次(对于Bert,一个Batch可以输入一到两个句子,输入两个句子时,两个直接拼接就好了),m代表一个batch的单词数,n表示词向量的长度。
Bert的输入是三种输入之和(维度设定我们与本系列上一篇文章保持相同):
token_embeddings 和Transformer完全一样。
segment_embeddings 用来标记句子。第一个句子每个单词标0,第二个句子的每个单词标1。
pos_embeddings 用来标记位置,维度和Transformer中的一样,但是Bert的pos_embeddings是训练出来的(这意味它成为了神经网络里要训练的参数了)。
def get_token_and_segments(tokens_a, tokens_b=None):
"""
bert的输入之一:token embeddings
bert的输入之二:segment embeddings
pos_embeddings在后面的模型里面
"""
tokens = ['<cls>'] + tokens_a + ['<sep>']
segments = [0] * (len(tokens_a) + 2)
if tokens_b is not None:
tokens += tokens_b + ['<sep>']
segments += [1] * (len(tokens_b) + 1)
return tokens, segments
Bertmodel
Bert的单个EncoderLayer和Transformer是一样的,我们直接把上一节的代码复制过来就好。
组装好。
class BertModel(nn.Module):
def __init__(self, vocab, n, d_ff, h, n_layers,
max_len=1000, k=768, v=768):
super(BertModel, self).__init__()
self.token_embeddings = nn.Embedding(vocab, n) # [B, m]->[B, m, vocab]->[B, m, n]
self.segment_embeddings = nn.Embedding(2, n) # [B, m]->[B, m, 2]->[B, m, n]
self.pos_embeddings = nn.Parameter(torch.randn(1, max_len, n)) # [1, max_len, n]
self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)
for _ in range(n_layers)])
def forward(self, tokens, segments, m): # m是句子长度
X = self.token_embeddings(tokens) + \
self.segment_embeddings(segments)
X += self.pos_embeddings[:, :X.shape[1], :]
for layer in self.layers:
X, attn = layer(X)
return X
简单测试一下。
# 弄一点数据测试一下
tokens = torch.randint(0, 100, (2, 10)) # [B, m]
segments = torch.randint(0, 2, (2, 10)) # [B, m]
m = 10
bert = BertModel(100, 768, 3072, 12, 12)
out = bert(tokens, segments, m)
print(out.shape) # [2, 10, 768]
预训练任务
Bert在训练时要做两种训练,这里先画个图表示架构,后面给出分析和代码。
MLM
Maked language model,是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。
在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。(这里就不深入了)
class MLM(nn.Module):
def __init__(self, vocab, n, mlm_hid):
super(MLM, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(n, mlm_hid),
nn.ReLU(),
nn.LayerNorm(mlm_hid),
nn.Linear(mlm_hid, vocab))
def forward(self, X, P):
# X: [B, m, n]
# P: [B, p]
# 这里P指的是记录了要mask的元素的矩阵,若P(i,j)==k,表示X(i,k)被mask了
p = P.shape[1]
P = P.reshape(-1)
batch_size = X.shape[0]
batch_idx = torch.arange(batch_size)
batch_idx = torch.repeat_interleave(batch_idx, p)
X = X[batch_idx, P].reshape(batch_size, p, -1) # [B, p, n]
out = self.mlp(X)
return out
这里的forward的逻辑有点麻烦,要读懂的话可以要手推一下。p是每一个Batch中mask的词的个数。(即在一个Batch中,m个词挑出了p个)。其实意会也行:就是训练时,在X[B, m, n]里每个batch有p个是<mask>,我们把它们挑出来,得到[B, p, n]。
NSP
Next Sentence Prediction的任务是判断句子B是否是句子A的下文。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合is_next关系,另外50%的第二句话是随机从预料中提取的,不符合is_next关系。分别记为1 | 0。
这个关系由每个句子的第一个token——<cls>捕捉。
class NSP(nn.Module):
def __init__(self, n, nsp_hid):
super(NSP, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(n, nsp_hid),
nn.Tanh(),
nn.Linear(nsp_hid, 2))
def forward(self, X):
# X: [B, m, n]
X = X[:, 0, :] # [B, n]
out = self.mlp(X) # [B, 2]
return out
Bert
下面拼装Bert。
class Bert(nn.Module):
def __init__(self, vocab, n, d_ff, h, n_layers,
max_len=1000, k=768, v=768, mlm_feat=768, nsp_feat=768):
super(Bert, self).__init__()
self.encoder = BertModel(vocab, n, d_ff, h, n_layers, max_len, k, v)
self.mlm = MLM(vocab, n, mlm_feat)
self.nsp = NSP(n, nsp_feat)
def forward(self, tokens, segments, m, P=None):
X = self.encoder(tokens, segments, m)
mlm_out = self.mlm(X, P) if P is not None else None
nsp_out = self.nsp(X)
return X, mlm_out, nsp_out
后话
netron可视化
利用netron可视化。
test_tokens = torch.randint(0, 100, (2, 10)) # [B, m]
test_segments = torch.randint(0, 2, (2, 10)) # [B, m]
test_P = torch.tensor([[1, 2, 4, 6, 8], [1, 3, 4, 5, 6]])
test_m = 10
test_bert = Bert(100, 768, 3072, 12, 12)
test_X, test_mlm_out, test_nsp_out = test_bert(test_tokens, test_segments, test_m, test_P)
modelData = "./demo.pth"
torch.onnx.export(test_bert, (test_tokens, test_segments), modelData)
netron.start(modelData)
截取部分看一下。
code2flow可视化
code2flow可以可视化代码函数和类的相互调用关系。
code2flow.code2flow([r'代码路径.py'], '输出路径.svg')
这里生成的png,其实svg清晰得多。
fine tuning
Bert的精髓在于,Bert只是一个编码器(Encoder),经过MLM和NSP两个任务的训练之后,可以自己在它的基础上训练一个Decoder来输出特定的值、得到特定的效果。这也是Bert的神奇和魅力所在!通过两个任务训练出一个编码器,然后可以通过不同的Decoder达到各种效果!
持续探索Bert......