FEW-NERD: A Few-shot Named Entity Recognition Dataset
Paper and code: https://ningding97.github.io/fewnerd/
I. Project structure
Few-NERD
├── data
│   ├── supervised            # randomly split into train/dev/test at a 7:2:1 ratio; all three splits contain all 66 fine-grained entity types (not intended for few-shot research)
│   └── episode-data          # sampled episode data
│       ├── inter             # split by fine-grained type: within each coarse-grained type, 60% of the fine-grained types are randomly assigned to the training set and 20% / 20% to the dev and test sets
│       └── intra             # split by coarse-grained type: train = People, MISC, Art, Product; dev = Event, Building; test = ORG, LOC
├── model
│   ├── proto.py              # prototypical network model
│   └── nnshot.py
├── premodel                  # directory for pre-trained models
│   ├── bert_base_uncased
│   ├── chinese-bert-wwm
│   └── roberta-wwm-ext-base
├── util                      # utilities
│   ├── data_loader.py        # dataset loading
│   ├── fewshotsampler.py     # few-shot episode sampling
│   ├── framework.py          # training / validation / testing framework (FewShotNERFramework)
│   ├── metric.py             # distance metrics
│   ├── supervised_util.py    # utilities for the supervised (non-few-shot) mode
│   ├── viterbi.py            # Viterbi decoder
│   └── word_encoder.py       # word encoder
├── run_supervised.py         # entry point for the supervised mode (non-few-shot)
├── run_supervised.sh
├── run_train.sh              # training launch script
└── train_demo.py             # few-shot entry point
II. Dataset structure
{
"support": {
"word": [
["it", "..."],
["richard","..."],
["it","..."],
["new","..."],
["the", "..."]
],
"label": [
["O","..."],
["O","..."],
["O","..."],
["O","..."],
["O","..."],
]
},
"query": {
"word": [
["this","..."],
["professor","..."],
["she","..."],
["spycraft","..."],
["it","..."]
],
"label": [
["O","person-scholar","..."],
[],
[],
[],
[]
]
},
"types": [
"product-game",
"organization-religion",
"art-painting",
"event-other",
"person-scholar"
]
}
Each sample is an episode containing a support set and a query set, each holding 5 sentences here (a 5-way setting). The support set and the query set cover the same entity types, and the types field lists the entity types that appear in the episode.
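To make the episode format concrete, here is a minimal sketch that reads one episode from the sampled data (the file name below is hypothetical; each episode file is JSON lines, one episode per line):
import json

# hypothetical file name under data/episode-data/; one JSON episode per line
path = 'data/episode-data/inter/train_5_1.jsonl'

with open(path, encoding='utf-8') as f:
    episode = json.loads(f.readline())            # read the first episode

print(episode['types'])                           # the N entity types sampled for this episode
for words, labels in zip(episode['support']['word'], episode['support']['label']):
    # every support sentence is a token list with a parallel label list ('O' or a fine-grained type)
    print(list(zip(words, labels))[:5])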
III. Code analysis
1. Entry point
python3 train_demo.py \
--mode inter \
--pretrain_ckpt ./premodel/bert_base_uncased \
--lr 1e-4 \
--batch_size 8 \
--trainN 5 \
--K 1 \
--Q 1 \
--train_iter 10000 \
--val_iter 500 \
--test_iter 5000 \
--val_step 1000 \
--max_length 128 \
--model proto \
--tau 0.32 # temperature coefficient τ used by the StructShot Viterbi decoder (for the 5~10-shot setting use tau=0.434)
2. Argument definitions
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='inter',
help='training mode, must be in [inter, intra]')
parser.add_argument('--pretrainmodel', default='Ernie',
help='pretraining mode, must be in [Bert, Ernie]')
parser.add_argument('--trainN', default=5, type=int,
help='N in train')
parser.add_argument('--N', default=5, type=int,
help='N way')
parser.add_argument('--K', default=1, type=int,
help='K shot')
parser.add_argument('--Q', default=5, type=int,
help='Num of query per class')
parser.add_argument('--batch_size', default=4, type=int,
help='batch size')
parser.add_argument('--train_iter', default=600, type=int,
help='num of iters in training')
parser.add_argument('--val_iter', default=100, type=int,
help='num of iters in validation')
parser.add_argument('--test_iter', default=500, type=int,
help='num of iters in testing')
parser.add_argument('--val_step', default=20, type=int,
help='val after training how many iters')
parser.add_argument('--model', default='proto',
help='model name, must be proto, nnshot, or structshot')
parser.add_argument('--max_length', default=128, type=int,
help='max length')
parser.add_argument('--lr', default=1e-4, type=float,
help='learning rate')
parser.add_argument('--grad_iter', default=1, type=int,
help='accumulate gradient every x iterations')
parser.add_argument('--load_ckpt', default=None,
help='load ckpt')
parser.add_argument('--save_ckpt', default=None,
help='save ckpt')
parser.add_argument('--fp16', action='store_true',
help='use nvidia apex fp16')
parser.add_argument('--only_test', action='store_true',
help='only test')
parser.add_argument('--ckpt_name', type=str, default='',
help='checkpoint name.')
parser.add_argument('--seed', type=int, default=0,
help='random seed')
parser.add_argument('--ignore_index', type=int, default=-1,
help='label index to ignore when calculating loss and metrics')
parser.add_argument('--use_sampled_data', default=True, action='store_true',
help='use released sampled data, the data should be stored at "data/episode-data/" ')
# only for bert / roberta / ernie
parser.add_argument('--pretrain_ckpt', default=None,
help='bert / roberta / ernie pre-trained checkpoint')
# only for prototypical networks
parser.add_argument('--dot', action='store_true',
help='use dot instead of L2 distance for proto')
# With action='store_true' the default value is False; the flag only becomes True when it is passed explicitly on the command line (see the small example after this argument block).
# only for structshot
parser.add_argument('--tau', default=0.05, type=float,
help='StructShot parameter to re-normalizes the transition probabilities')
# experiment
parser.add_argument('--use_sgd_for_bert', action='store_true',
help='use SGD instead of AdamW for BERT.')
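As a small, self-contained illustration of the store_true behaviour noted above (it simply reuses the --dot flag from the snippet; nothing here is specific to Few-NERD):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dot', action='store_true',
                    help='use dot instead of L2 distance for proto')

print(parser.parse_args([]).dot)           # False: the flag was not given
print(parser.parse_args(['--dot']).dot)    # True: the flag was passed explicitly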
3. Data loading
train_data_loader = get_loader(opt.train, tokenizer,
N=trainN, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
val_data_loader = get_loader(opt.dev, tokenizer,
N=N, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
test_data_loader = get_loader(opt.test, tokenizer,
N=N, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
Arguments:
- opt.train / opt.dev / opt.test: paths to the training / validation / test data;
- tokenizer: the tokenizer of the pre-trained BERT model, used to tokenize and encode the input;
- N=5: number of entity types per episode;
- K=1: number of support-set examples per entity type in each episode;
- Q=1: number of query-set examples per entity type in each episode;
- batch_size=8: number of episodes (samples) per batch;
- max_length=128: maximum sequence length; longer sentences are truncated;
- ignore_index=-1: label index ignored when computing the loss and metrics (default -1).
def get_loader(filepath, tokenizer, N, K, Q, batch_size, max_length,
num_workers=8, collate_fn=collate_fn, ignore_index=-1, use_sampled_data=True):
if not use_sampled_data:
dataset = FewShotNERDatasetWithRandomSampling(filepath, tokenizer, N, K, Q, max_length, ignore_label_id=ignore_index)
else:
dataset = FewShotNERDataset(filepath, tokenizer, max_length, ignore_label_id=ignore_index)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=num_workers,
collate_fn=collate_fn)
return data_loader
Since use_sampled_data=True, the FewShotNERDataset class is used:
class FewShotNERDataset(FewShotNERDatasetWithRandomSampling):
def __init__(self, filepath, tokenizer, max_length, ignore_label_id=-1):
if not os.path.exists(filepath):
print("[ERROR] Data file does not exist!")
assert(0)
self.class2sampleid = {}
self.tokenizer = tokenizer
self.samples = self.__load_data_from_file__(filepath)
self.max_length = max_length
self.ignore_label_id = ignore_label_id
def __load_data_from_file__(self, filepath):
with open(filepath) as f:
lines = f.readlines()
for i in range(len(lines)):
lines[i] = json.loads(lines[i].strip()) # strip the trailing newline and parse the JSON line
return lines
def __additem__(self, d, word, mask, text_mask, label):
……
def __get_token_label_list__(self, words, tags):
……
def __populate__(self, data, savelabeldic=False):
……
def __getitem__(self, index):
……
def __len__(self):
return len(self.samples)
__init__() is the constructor: it runs as soon as the class is instantiated, while the other instance methods are not executed automatically; they must be called either from within the constructor or explicitly on an instance.
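A minimal usage sketch of the dataset class; the import path, tokenizer and episode file name below are assumptions made only for illustration:
from transformers import BertTokenizer
from util.data_loader import FewShotNERDataset   # assumed import path (util/data_loader.py)

tokenizer = BertTokenizer.from_pretrained('./premodel/bert_base_uncased')
# instantiating the class runs __init__, which already reads the whole episode file
dataset = FewShotNERDataset('data/episode-data/inter/train_5_1.jsonl',   # hypothetical file name
                            tokenizer, max_length=128)
print(len(dataset))    # number of sampled episodes (__len__)
episode = dataset[0]   # only indexing (or the DataLoader) triggers __getitem__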
4. Loading the prototypical network
A Proto model instance is created, together with a FewShotNERFramework instance that receives the training, validation and test data loaders as arguments.
if model_name == 'proto':
print('use proto')
model = Proto(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data)
The Proto class
class Proto(util.framework.FewShotNERModel):
def __init__(self,word_encoder, dot=False, ignore_index=-1):
util.framework.FewShotNERModel.__init__(self, word_encoder, ignore_index=ignore_index)
self.drop = nn.Dropout()
self.dot = dot
The constructor initializes the object. It takes three arguments: word_encoder (the word encoder), dot (a boolean indicating whether the dot product is used as the similarity measure), and ignore_index (the label index to ignore). It first calls the parent-class constructor, then creates a dropout layer (nn.Dropout) and stores the dot flag.
def __dist__(self, x, y, dim):
if self.dot:
return (x * y).sum(dim)
else:
return -(torch.pow(x - y, 2)).sum(dim)
def __batch_dist__(self, S, Q, q_mask):
# S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
assert Q.size()[:2] == q_mask.size()
Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim]
return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)
def __get_proto__(self, embedding, tag, mask):
proto = []
embedding = embedding[mask==1].view(-1, embedding.size(-1))
tag = torch.cat(tag, 0)
assert tag.size(0) == embedding.size(0)
for label in range(torch.max(tag)+1):
proto.append(torch.mean(embedding[tag==label], 0))
proto = torch.stack(proto)
return proto
- __dist__(): computes the similarity between two vectors. If dot is True the dot product is used; otherwise the negative squared Euclidean (L2) distance is used, so that in both cases a larger value means "closer" (see the sketch below);
- __batch_dist__(): computes, for every non-masked token in the query sentences, its distance to each class prototype;
- __get_proto__(): computes the prototype of each class. It takes embedding, tag and mask; the token embeddings selected by the mask are grouped by label, and the prototype of a class is the mean embedding of all tokens carrying that label.
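A tiny sketch in plain PyTorch (independent of the repository code) of the two similarity options, showing that with the negative squared L2 distance the largest score still picks the nearest prototype:
import torch

protos = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # two class prototypes, embed_dim = 2
q = torch.tensor([0.9, 0.1])                      # one query-token embedding

dot_scores = (protos * q).sum(-1)                 # dot product:          tensor([ 0.90,  0.10])
l2_scores = -((protos - q) ** 2).sum(-1)          # negative squared L2:  tensor([-0.02, -1.62])
print(dot_scores.argmax().item(), l2_scores.argmax().item())   # both pick class 0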
def forward(self, support, query):
'''
support: Inputs of the support set.
query: Inputs of the query set.
N: Num of classes
K: Num of instances for each class in the support set
Q: Num of instances in the query set
'''
support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768]
query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768]
support_emb = self.drop(support_emb) # apply dropout to the support embeddings to reduce overfitting
query_emb = self.drop(query_emb)
# Prototypical Networks
logits = []
current_support_num = 0
current_query_num = 0
assert support_emb.size()[:2] == support['mask'].size()
assert query_emb.size()[:2] == query['mask'].size()
for i, sent_support_num in enumerate(support['sentence_num']):
sent_query_num = query['sentence_num'][i] # number of query-set sentences in the current episode
# Calculate prototype for each class
support_proto = self.__get_proto__(
support_emb[current_support_num:current_support_num+sent_support_num],
support['label'][current_support_num:current_support_num+sent_support_num],
support['text_mask'][current_support_num: current_support_num+sent_support_num])
# calculate distance to each prototype
logits.append(self.__batch_dist__(
support_proto,
query_emb[current_query_num:current_query_num+sent_query_num],
query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num]
current_query_num += sent_query_num
current_support_num += sent_support_num
# print('logits', logits)
logits = torch.cat(logits, 0) # concatenate the per-episode logits into a single tensor
_, pred = torch.max(logits, 1) # take the class with the highest logit (i.e. the nearest prototype) as the prediction
return logits, pred
assert support_emb.size()[:2] == support['mask'].size()
assert query_emb.size()[:2] == query['mask'].size()
These assertions check that each embedding tensor has the same (num_sentences, num_tokens) shape as its mask.
support_proto = self.__get_proto__(
support_emb[current_support_num:current_support_num+sent_support_num],
support['label'][current_support_num:current_support_num+sent_support_num],
support['text_mask'][current_support_num: current_support_num+sent_support_num])
logits.append(self.__batch_dist__(
support_proto,
query_emb[current_query_num:current_query_num+sent_query_num],
query['text_mask'][current_query_num: current_query_num+sent_query_num]))
__get_proto__ is called to compute the prototypes of the current episode, and __batch_dist__ is called to compute the distance between every query token and each prototype; the result is appended to the logits list (a small numeric sketch follows).
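A small numeric sketch of the prototype step (again plain PyTorch, not the repository code): with three support tokens labelled [0, 0, 1], each prototype is the mean embedding of the tokens carrying that label:
import torch

emb = torch.tensor([[1.0, 1.0],   # token 0, label 0
                    [3.0, 1.0],   # token 1, label 0
                    [0.0, 4.0]])  # token 2, label 1
tag = torch.tensor([0, 0, 1])

protos = torch.stack([emb[tag == c].mean(0) for c in range(int(tag.max()) + 1)])
print(protos)  # tensor([[2., 1.],   <- prototype of class 0
               #         [0., 4.]])  <- prototype of class 1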
5. FewShotNERFramework
This framework class handles training, validation and testing of the few-shot NER model. The FewShotNERFramework constructor takes train_data_loader, val_data_loader, test_data_loader, viterbi and a few other arguments.
class FewShotNERFramework:
def __init__(self, train_data_loader, val_data_loader, test_data_loader, viterbi=False, N=None, train_fname=None, tau=0.05, use_sampled_data=True):
'''
train_data_loader: DataLoader for training.
val_data_loader: DataLoader for validating.
test_data_loader: DataLoader for testing.
'''
self.train_data_loader = train_data_loader
self.val_data_loader = val_data_loader
self.test_data_loader = test_data_loader
self.viterbi = viterbi
if viterbi:
abstract_transitions = get_abstract_transitions(train_fname, use_sampled_data=use_sampled_data)
self.viterbi_decoder = ViterbiDecoder(N+2, abstract_transitions, tau)
Here train_data_loader, val_data_loader and test_data_loader load the training, validation and test data. viterbi is a boolean indicating whether Viterbi decoding is used; if it is True, a viterbi_decoder object is initialized as well.
get_abstract_transitions() computes abstract transition probabilities over the tag sequences. It takes train_fname and use_sampled_data. If use_sampled_data is True, the sampled episode data is loaded from train_fname and the tag sequences are taken from its support and query sets; otherwise the full dataset is loaded from train_fname and the tag sequence of every sample is used.
def get_abstract_transitions(train_fname, use_sampled_data=True):
"""
Compute abstract transitions on the training dataset for StructShot
"""
if use_sampled_data:
samples = data_loader.FewShotNERDataset(train_fname, None, 1).samples
tag_lists = []
for sample in samples:
tag_lists += sample['support']['label'] + sample['query']['label']
else:
samples = data_loader.FewShotNERDatasetWithRandomSampling(train_fname, None, 1, 1, 1, 1).samples
tag_lists = [sample.tags for sample in samples]
s_o, s_i = 0., 0.
o_o, o_i = 0., 0.
i_o, i_i, x_y = 0., 0., 0.
for tags in tag_lists:
if tags[0] == 'O': s_o += 1
else: s_i += 1
for i in range(len(tags)-1):
p, n = tags[i], tags[i+1]
if p == 'O':
if n == 'O': o_o += 1
else: o_i += 1
else:
if n == 'O':
i_o += 1
elif p != n:
x_y += 1
else:
i_i += 1
trans = []
trans.append(s_o / (s_o + s_i))
trans.append(s_i / (s_o + s_i))
trans.append(o_o / (o_o + o_i))
trans.append(o_i / (o_o + o_i))
trans.append(i_o / (i_o + i_i + x_y))
trans.append(i_i / (i_o + i_i + x_y))
trans.append(x_y / (i_o + i_i + x_y))
return trans
- First, the function obtains the sample list samples according to the data-loading mode (sampled episode data or the full dataset);
- then it builds the tag list tag_lists: in the sampled-data mode the support-set and query-set labels are taken directly from each episode, while in the full-data mode the tag sequence of every sample is used;
- next, it accumulates the statistics needed for the abstract transition probabilities. For every tag sequence it counts: sequences starting with O vs. with an entity tag; transitions O→O and O→entity; transitions entity→O, entity→same entity tag, and entity→different entity tag;
- finally, it returns the list trans of abstract transition probabilities, one entry per transition type (a worked toy example follows).
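As a hand-checkable illustration, the two toy tag sequences below (made up, not taken from the dataset) give the following counts and trans values under the code above:
# toy tag sequences, only to trace the counting logic
tag_lists = [['O', 'person', 'person', 'O'],
             ['org', 'O']]
# sequence starts:    s_o = 1 (seq 1 starts with 'O'),  s_i = 1 (seq 2 starts with 'org')
# from 'O':           o_o = 0,  o_i = 1  ('O' -> 'person')
# from an entity tag: i_o = 2 ('person' -> 'O', 'org' -> 'O'),  i_i = 1 ('person' -> 'person'),  x_y = 0
# trans = [1/2, 1/2, 0/1, 1/1, 2/3, 1/3, 0/3] = [0.5, 0.5, 0.0, 1.0, 0.667, 0.333, 0.0]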
__load_model__() loads a model checkpoint from the given path. If the path points to a file, the checkpoint is loaded with torch.load and returned; otherwise an exception is raised.
def __load_model__(self, ckpt):
'''
ckpt: Path of the checkpoint
return: Checkpoint dict
'''
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
print("Successfully loaded checkpoint '%s'" % ckpt)
return checkpoint
else:
raise Exception("No checkpoint found at '%s'" % ckpt)
item() extracts a single scalar from a tensor in a way that works across PyTorch versions: before PyTorch 0.4 the scalar is read with [0] indexing, from 0.4 onward with the item() method.
def item(self, x):
'''
PyTorch before and after 0.4
'''
torch_version = torch.__version__.split('.')
if int(torch_version[0]) == 0 and int(torch_version[1]) < 4:
return x[0]
else:
return x.item()
train() defines the training loop used to iterate over episodes while training the model.
def train(self,
model,
model_name,
learning_rate=1e-1,
train_iter=30000,
val_iter=1000,
val_step=2000,
load_ckpt=None,
save_ckpt=None,
warmup_step=300,
grad_iter=1,
fp16=False,
use_sgd_for_bert=False):
'''
model: a FewShotREModel instance
model_name: Name of the model
B: Batch size
N: Num of classes for each batch
K: Num of instances for each class in the support set
Q: Num of instances for each class in the query set
ckpt_dir: Directory of checkpoints
learning_rate: Initial learning rate
train_iter: Num of iterations of training
val_iter: Num of iterations of validating
val_step: Validate every val_step steps
'''
print("Start training...")
# Init optimizer
parameters_to_optimize = list(model.named_parameters()) # collect all named parameters of the model
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] # parameter names that should not receive weight decay
# split the parameters into two groups, with and without weight decay, each with its own weight_decay value
parameters_to_optimize = [
{'params': [p for n, p in parameters_to_optimize
if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in parameters_to_optimize
if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# choose between SGD and AdamW as the optimizer
if use_sgd_for_bert:
optimizer = torch.optim.SGD(parameters_to_optimize, lr=learning_rate)
else:
optimizer = AdamW(parameters_to_optimize, lr=learning_rate, correct_bias=False)
# create a linear learning-rate scheduler with warmup
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=train_iter)
# load model: if a checkpoint path is given, load its state dict and copy the matching parameters into the current model
if load_ckpt:
state_dict = self.__load_model__(load_ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
print('ignore {}'.format(name))
continue
print('load {} from {}'.format(name, load_ckpt))
own_state[name].copy_(param)
# if mixed-precision training (fp16) is enabled, initialize model and optimizer with NVIDIA Apex
if fp16:
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
# switch the model to training mode
model.train()
# Training
best_f1 = 0.0
iter_loss = 0.0
iter_sample = 0
pred_cnt = 0
label_cnt = 0
correct_cnt = 0
it = 0
# training loop, until the specified number of iterations train_iter is reached
while it + 1 < train_iter:
for _, (support, query) in enumerate(self.train_data_loader): # fetch one batch of (support, query) episodes from the training data loader
label = torch.cat(query['label'], 0) # concatenate the query-set label lists into a single tensor
# move the data and labels to the GPU if CUDA is available
if torch.cuda.is_available():
for k in support:
if k != 'label' and k != 'sentence_num':
support[k] = support[k].cuda()
query[k] = query[k].cuda()
label = label.cuda()
# forward pass: the model returns the logits and the predictions pred
logits, pred = model(support, query)
# assert that the number of query tokens matches the number of labels; on mismatch the shapes are printed and an AssertionError is raised
assert logits.shape[0] == label.shape[0], print(logits.shape, label.shape)
# compute the loss with model.loss and divide by grad_iter for gradient accumulation
loss = model.loss(logits, label) / float(grad_iter)
# entity-level metrics from the predictions and labels: correctly predicted entities correct, predicted-entity count tmp_pred_cnt, gold-entity count tmp_label_cnt
tmp_pred_cnt, tmp_label_cnt, correct = model.metrics_by_entity(pred, label)
# back-propagate, with or without mixed precision
if fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# every grad_iter iterations, take an optimizer step, advance the scheduler and zero the gradients
if it % grad_iter == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# accumulate the loss, predicted-entity count, gold-entity count and correct-prediction count
iter_loss += self.item(loss.data)
#iter_right += self.item(right.data)
pred_cnt += tmp_pred_cnt
label_cnt += tmp_label_cnt
correct_cnt += correct
iter_sample += 1 # count iterations since the last report
# every 100 iterations, and at every val_step, compute precision, recall and F1 and print them
if (it + 1) % 100 == 0 or (it + 1) % val_step == 0:
precision = correct_cnt / pred_cnt
recall = correct_cnt / label_cnt
f1 = 2 * precision * recall / (precision + recall)
sys.stdout.write('step: {0:4} | loss: {1:2.6f} | [ENTITY] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f}'\
.format(it + 1, iter_loss/ iter_sample, precision, recall, f1) + '\r')
# append the training metrics to a log file
train_item = {
"step": it + 1,
"loss": iter_loss/ iter_sample,
"precision": precision,
"recall": recall,
"f1": f1
}
with open('../result/train_result.txt', 'a', encoding='utf-8') as file:
json.dump(train_item, file, ensure_ascii=False)
file.write('\n')
sys.stdout.flush() # flush the output buffer so the progress line appears immediately
# run validation every val_step iterations
if (it + 1) % val_step == 0:
_, _, f1, _, _, _, _ = self.eval(model, val_iter)
model.train()
# keep the checkpoint with the highest validation F1
if f1 > best_f1:
print('Best checkpoint')
torch.save({'state_dict': model.state_dict()}, save_ckpt)
best_f1 = f1
iter_loss = 0.
iter_sample = 0.
pred_cnt = 0
label_cnt = 0
correct_cnt = 0
if (it + 1) == train_iter:
break
it += 1
print("\n####################\n")
print("Finish training " + model_name)
eval() evaluates the FewShotNERModel: it loads model parameters when a checkpoint is given, runs evaluation, and computes entity precision, recall, F1 and an error analysis. The model is first switched to evaluation mode; when no checkpoint path is given the validation set is used, otherwise the test set is used and, unless ckpt is 'none', the checkpoint parameters are loaded into the model. During evaluation it iterates over the dataset and counts predicted entities, gold entities, correctly predicted entities and the different error types; finally it computes and returns entity precision, recall, F1 and the error ratios.
def eval(self,
model,
eval_iter,
ckpt=None):
'''
model: a FewShotNERModel instance
B: Batch size
N: Num of classes for each batch
K: Num of instances for each class in the support set
Q: Num of instances for each class in the query set
eval_iter: Num of iterations
ckpt: Checkpoint path. Set as None if using current model parameters.
return: Accuracy
'''
model.eval() # switch the model to evaluation mode
if ckpt is None:
print("Use val dataset")
eval_dataset = self.val_data_loader
else:
print("Use test dataset")
if ckpt != 'none':
state_dict = self.__load_model__(ckpt)['state_dict']
own_state = model.state_dict()
for name, param in state_dict.items():
if name not in own_state:
continue
own_state[name].copy_(param)
eval_dataset = self.test_data_loader
pred_cnt = 0 # pred entity cnt
label_cnt = 0 # true label entity cnt
correct_cnt = 0 # correct predicted entity cnt
fp_cnt = 0 # misclassify O as I-
fn_cnt = 0 # misclassify I- as O
total_token_cnt = 0 # total token cnt
within_cnt = 0 # span correct but of wrong fine-grained type
outer_cnt = 0 # span correct but of wrong coarse-grained type
total_span_cnt = 0 # span correct
eval_iter = min(eval_iter, len(eval_dataset))
# no gradient computation is needed during evaluation
with torch.no_grad():
it = 0
# evaluation loop, until the specified number of iterations eval_iter is reached
while it + 1 < eval_iter:
for _, (support, query) in enumerate(eval_dataset): # fetch one batch of (support, query) episodes from eval_dataset
# concatenate the query-set labels into a single tensor
label = torch.cat(query['label'], 0)
# move the support set, query set and labels to the GPU if CUDA is available
if torch.cuda.is_available():
for k in support:
if k != 'label' and k != 'sentence_num':
support[k] = support[k].cuda()
query[k] = query[k].cuda()
label = label.cuda()
# forward pass: the model returns the logits and the predictions pred
logits, pred = model(support, query)
# if Viterbi decoding is enabled, decode the predictions with the Viterbi decoder
if self.viterbi:
pred = self.viterbi_decode(logits, query['label'])
viterbi_decode() performs Viterbi decoding to find the optimal tag sequence.
def viterbi_decode(self, logits, query_tags):
emissions_list = self.__get_emmissions__(logits, query_tags)
pred = []
for i in range(len(query_tags)):
sent_scores = emissions_list[i].cpu() # move the emission scores of the i-th sentence to the CPU
sent_len, n_label = sent_scores.shape
sent_probs = F.softmax(sent_scores, dim=1) # turn the scores into a probability distribution over labels
start_probs = torch.zeros(sent_len) + 1e-6 # a tiny (1e-6) probability column for the extra start label
sent_probs = torch.cat((start_probs.view(sent_len, 1), sent_probs), 1) # prepend it as the first column, the input format the Viterbi decoder expects
# take the log of the probabilities and run the Viterbi decoder to obtain the predicted tag sequence vit_labels
feats = self.viterbi_decoder.forward(torch.log(sent_probs).view(1, sent_len, n_label+1))
vit_labels = self.viterbi_decoder.viterbi(feats)
vit_labels = vit_labels.view(sent_len)
vit_labels = vit_labels.detach().cpu().numpy().tolist()
# append the predicted labels to pred (minus 1 to compensate for the prepended start label)
for label in vit_labels:
pred.append(label-1)
# return the predicted tag sequence as a tensor moved to the GPU
return torch.tensor(pred).cuda()
- viterbi_decode() takes two inputs, logits and query_tags: logits are the model outputs, i.e. a prediction score for every token, and query_tags are the tag lists of the query sentences;
- the __get_emmissions__ helper then recovers the emission matrices, i.e. the per-token emission scores of each sentence;
# split the flat logits back into per-sentence emission matrices
def __get_emmissions__(self, logits, tags_list):
# split [num_of_query_tokens, num_class] into [[num_of_token_in_sent, num_class], ...]
emmissions = []
current_idx = 0
for tags in tags_list:
emmissions.append(logits[current_idx:current_idx+len(tags)])
current_idx += len(tags)
assert current_idx == logits.size()[0]
return emmissions
Finally, the predicted tag sequence is returned as a tensor on the GPU.
Viterbi decoding finds the most likely tag sequence given the model outputs, which is essential for sequence-labelling tasks such as named entity recognition and part-of-speech tagging; a minimal sketch of the idea follows.
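As an illustration of the idea only (this is not the repository's ViterbiDecoder, whose interface is only partly visible above), a minimal log-space Viterbi over per-token label scores and a fixed transition matrix could look like this:
import torch

def viterbi(log_emissions: torch.Tensor, log_trans: torch.Tensor) -> list:
    """log_emissions: [seq_len, n_labels]; log_trans: [n_labels, n_labels] (from -> to)."""
    seq_len = log_emissions.size(0)
    score = log_emissions[0].clone()             # best log-score of ending in each label at position 0
    backptr = []
    for t in range(1, seq_len):
        # candidate[i, j] = score[i] + log_trans[i, j] + log_emissions[t, j]
        cand = score.unsqueeze(1) + log_trans + log_emissions[t].unsqueeze(0)
        score, best_prev = cand.max(dim=0)       # best previous label for each current label
        backptr.append(best_prev)
    path = [int(score.argmax())]                 # best final label
    for best_prev in reversed(backptr):          # trace the best path backwards
        path.append(int(best_prev[path[-1]]))
    return path[::-1]

# toy run: 3 tokens, 2 labels, uniform transitions -> the path follows the per-token argmax
emissions = torch.log(torch.tensor([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]]))
trans = torch.log(torch.full((2, 2), 0.5))
print(viterbi(emissions, trans))                 # [0, 1, 1]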
# entity-level metrics between the predictions pred and the gold labels label:
# predicted-entity count tmp_pred_cnt, gold-entity count tmp_label_cnt, correctly predicted entities correct
tmp_pred_cnt, tmp_label_cnt, correct = model.metrics_by_entity(pred, label)
# error_analysis breaks down the errors between pred and label: false positives fp (O predicted as an entity),
# false negatives fn (an entity predicted as O), total token count token_cnt, spans with correct boundaries but the
# wrong fine-grained type (within), spans with correct boundaries but the wrong coarse-grained type (outer),
# and the total number of correctly detected spans (total_span)
fp, fn, token_cnt, within, outer, total_span = model.error_analysis(pred, label, query)
# accumulate the counts
pred_cnt += tmp_pred_cnt
label_cnt += tmp_label_cnt
correct_cnt += correct
fn_cnt += self.item(fn.data)
fp_cnt += self.item(fp.data)
total_token_cnt += token_cnt
outer_cnt += outer
within_cnt += within
total_span_cnt += total_span
if it + 1 == eval_iter:
break
it += 1
# compute the final evaluation metrics
epsilon = 1e-6
precision = correct_cnt / (pred_cnt + epsilon)
recall = correct_cnt / (label_cnt + epsilon)
f1 = 2 * precision * recall / (precision + recall + epsilon)
fp_error = fp_cnt / total_token_cnt
fn_error = fn_cnt / total_token_cnt
within_error = within_cnt / (total_span_cnt + epsilon)
outer_error = outer_cnt / (total_span_cnt + epsilon)
sys.stdout.write('[EVAL] step: {0:4} | [ENTITY] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'.format(it + 1, precision, recall, f1) + '\r')
dev_item = {
"step": it + 1,
"precision": precision,
"recall": recall,
"f1": f1
}
with open('../result/dev_result.txt', 'a', encoding='utf-8') as file:
json.dump(dev_item, file, ensure_ascii=False)
file.write('\n')
sys.stdout.flush() # flush the output buffer
return precision, recall, f1, fp_error, fn_error, within_error, outer_error
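For intuition, the entity metrics computed above can be checked by hand with made-up counts (40 predicted entities, 50 gold entities, 30 of them correct):
# made-up counts, only to illustrate the formulas used in eval()
pred_cnt, label_cnt, correct_cnt = 40, 50, 30
epsilon = 1e-6
precision = correct_cnt / (pred_cnt + epsilon)                 # 0.75
recall = correct_cnt / (label_cnt + epsilon)                   # 0.60
f1 = 2 * precision * recall / (precision + recall + epsilon)   # ~0.667
print(precision, recall, f1)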