Bootstrap

半监督文本分类学习代码展示及最终总结

2021SC@SDUSC

目录

2021SC@SDUSC

2021SC@SDUSC

normal_bert.py 代码分析

python中torch.nn解析

实验结果

学期学习总结


​​​​​​​

2021SC@SDUSC

normal_bert.py 代码分析

代码输入包含七个部分,分别为
input_ids,input_mask,segment_ids,masked_lm_positions,mask_lm_ids,masked_lm_weights,next_sentence_labels.
input_ids:表示tokens的ids
input_mask:表示哪些是input,哪些是padding.len(input_ids)个1,后面继续补0.对于mask的词,主要占了全部vocabulary的15%左右,在代码中对于每个词80%replace with [mask],10% keep original,10% replace with random word.超过了mask的词数,则终止.
segment_ids:第一个句子到[SEP]为0,后面为1.主要是对输入进行区分,判断输入的两个句子.
masked_lm_positions:表示句子中mask的token的position.
mask_lm_ids:表示句子中mask的token的id.
masked_lm_weights:表示句子中mask的token的权重.
next_sentence_labels:表示两个句子是不是相连的.

代码示例

class ClassificationBert(nn.Module):
    def __init__(self, num_labels=2):
        super(ClassificationBert, self).__init__()

加载预训练bert模

;