2021SC@SDUSC
目录
2021SC@SDUSC
normal_bert.py 代码分析
代码输入包含七个部分,分别为
input_ids,input_mask,segment_ids,masked_lm_positions,mask_lm_ids,masked_lm_weights,next_sentence_labels.
input_ids:表示tokens的ids
input_mask:表示哪些是input,哪些是padding.len(input_ids)个1,后面继续补0.对于mask的词,主要占了全部vocabulary的15%左右,在代码中对于每个词80%replace with [mask],10% keep original,10% replace with random word.超过了mask的词数,则终止.
segment_ids:第一个句子到[SEP]为0,后面为1.主要是对输入进行区分,判断输入的两个句子.
masked_lm_positions:表示句子中mask的token的position.
mask_lm_ids:表示句子中mask的token的id.
masked_lm_weights:表示句子中mask的token的权重.
next_sentence_labels:表示两个句子是不是相连的.
代码示例
class ClassificationBert(nn.Module):
def __init__(self, num_labels=2):
super(ClassificationBert, self).__init__()
加载预训练bert模