Bootstrap

BiLSTM+CRF code

BiLSTM+CRF file

# Import the required modules
import numpy as np
from sklearn.model_selection import ShuffleSplit
from data_utils import ENTITIES, Documents, Dataset, SentenceExtractor, make_predictions
# from data_utils import Evaluator
from evaluator import Evaluator
from gensim.models import Word2Vec

# Data directory and entity/index mappings
data_dir = "./data/train"
ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))
idx2ent = dict([(v, k) for k, v in ent2idx.items()])

# Shuffle and split the documents into training and test sets (20 documents held out)
docs = Documents(data_dir=data_dir)
rs = ShuffleSplit(n_splits=1, test_size=20, random_state=2018)
train_doc_ids, test_doc_ids = next(rs.split(docs))
train_docs, test_docs = docs[train_doc_ids], docs[test_doc_ids]

# Model hyperparameters and sentence extraction
num_cates = max(ent2idx.values()) + 1
sent_len = 64
vocab_size = 3000
emb_size = 100
sent_pad = 10
sent_extractor = SentenceExtractor(window_size=sent_len, pad_size=sent_pad)
train_sents = sent_extractor(train_docs)
test_sents = sent_extractor(test_docs)

train_data = Dataset(train_sents, cate2idx=ent2idx)
train_data.build_vocab_dict(vocab_size=vocab_size)

test_data = Dataset(test_sents, word2idx=train_data.word2idx, cate2idx=ent2idx)
vocab_size = len(train_data.word2idx)

# Train a character-level Word2Vec model and build the pretrained embedding matrix
w2v_train_sents = []
for doc in docs:
    w2v_train_sents.append(list(doc.text))
w2v_model = Word2Vec(w2v_train_sents, vector_size=emb_size)

w2v_embeddings = np.zeros((vocab_size, emb_size))
for char, char_idx in train_data.word2idx.items():
    if char in w2v_model.wv:
        w2v_embeddings[char_idx] = w2v_model.wv[char]
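
# Optional sanity check: characters missing from the Word2Vec vocabulary (including
# '<unk>' at index 0) keep an all-zero embedding row; count how many got a vector.
covered = sum(1 for char in train_data.word2idx if char in w2v_model.wv)
print('pretrained embedding coverage: {}/{}'.format(covered, vocab_size))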
        
# Build the BiLSTM + CRF model
import keras
from keras.layers import Input, LSTM, Embedding, Bidirectional
from keras_contrib.layers import CRF
from keras.models import Model


def build_lstm_crf_model(num_cates, seq_len, vocab_size, model_opts=dict()):
    opts = {
        'emb_size': 256,
        'emb_trainable': True,
        'emb_matrix': None,
        'lstm_units': 256,
        'optimizer': keras.optimizers.Adam()
    }
    opts.update(model_opts)

    input_seq = Input(shape=(seq_len,), dtype='int32')
    if opts.get('emb_matrix') is not None:
        embedding = Embedding(vocab_size, opts['emb_size'], 
                              weights=[opts['emb_matrix']],
                              trainable=opts['emb_trainable'])
    else:
        embedding = Embedding(vocab_size, opts['emb_size'])
    x = embedding(input_seq)
    lstm = LSTM(opts['lstm_units'], return_sequences=True)
    x = Bidirectional(lstm)(x)
    crf = CRF(num_cates, sparse_target=True)
    output = crf(x)

    model = Model(input_seq, output)
    model.compile(opts['optimizer'], loss=crf.loss_function, metrics=[crf.accuracy])
    return model

# Instantiate the BiLSTM + CRF model; each extracted sentence is sent_len + 2 * sent_pad characters long
seq_len = sent_len + 2 * sent_pad
model = build_lstm_crf_model(num_cates, seq_len=seq_len, vocab_size=vocab_size, 
                             model_opts={'emb_matrix': w2v_embeddings, 'emb_size': 100, 'emb_trainable': False})
model.summary()

# Shapes of the training data; the labels carry a trailing singleton dimension because the CRF layer uses sparse_target=True
train_X, train_y = train_data[:]
print('train_X.shape', train_X.shape)
print('train_y.shape', train_y.shape)


# Train the BiLSTM + CRF model
model.fit(train_X, train_y, batch_size=64, epochs=10)


# Predict on the test set and convert the outputs back into documents
test_X, _ = test_data[:]
preds = model.predict(test_X, batch_size=64, verbose=True)
pred_docs = make_predictions(preds, test_data, sent_pad, docs, idx2ent)

# Report the evaluation metrics
f_score, precision, recall = Evaluator.f1_score(test_docs, pred_docs)
print('f_score: ', f_score)
print('precision: ', precision)
print('recall: ', recall)

# Show a test document with its gold annotations
sample_doc_id = list(pred_docs.keys())[3]
test_docs[sample_doc_id]


# Show the predicted annotations for the same document
pred_docs[sample_doc_id]

__init__.py

from .data_utils import *
from .evaluator import Evaluator

evaluator.py

class Evaluator(object):
    # Check whether two entities share the same category and have overlapping spans
    @staticmethod
    def check_match(ent_a, ent_b):
        return (ent_a.category == ent_b.category and
                max(ent_a.start_pos, ent_b.start_pos) < min(ent_a.end_pos, ent_b.end_pos))
    
    # Count one-to-one matches between two entity lists
    @staticmethod
    def count_intersects(ent_list_a, ent_list_b):
        num_hits = 0
        ent_list_b = ent_list_b.copy()
        for ent_a in ent_list_a:
            hit_ent = None
            for ent_b in ent_list_b:
                if Evaluator.check_match(ent_a, ent_b):
                    hit_ent = ent_b
                    break
            if hit_ent is not None:
                num_hits += 1
                ent_list_b.remove(hit_ent)
        return num_hits

    # Micro-averaged F1 score (with precision and recall) over all documents
    @staticmethod
    def f1_score(gt_docs, pred_docs):
        num_hits = 0
        num_preds = 0
        num_gts = 0
        for doc_id in gt_docs.doc_ids:
            gt_ents = gt_docs[doc_id].ents.ents
            pred_ents = pred_docs[doc_id].ents.ents
            num_gts += len(gt_ents)
            num_preds += len(pred_ents)
            num_hits += Evaluator.count_intersects(pred_ents, gt_ents)
        p = num_hits / num_preds
        r = num_hits / num_gts
        f = 2 * p * r / (p + r)
        return f, p, r
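
The matcher counts a predicted entity as a hit when it shares the gold entity's category and their spans overlap by at least one character; it is not an exact-boundary match, and each gold entity can be matched at most once. A minimal sketch of that behaviour, using a stand-in namedtuple (Ent is hypothetical, introduced only for illustration; any object with category, start_pos and end_pos attributes works):

from collections import namedtuple
from evaluator import Evaluator

Ent = namedtuple('Ent', ['category', 'start_pos', 'end_pos'])

gold = [Ent('Disease', 0, 4), Ent('Drug', 10, 15)]
pred = [Ent('Disease', 2, 5),   # same category, overlapping span -> counted as a hit
        Ent('Drug', 20, 25)]    # same category but no overlap -> not a hit

print(Evaluator.count_intersects(pred, gold))  # 1
# With 1 hit out of 2 predictions and 2 gold entities: precision = recall = F1 = 0.5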

data_utils.py

import os
import math
import numpy as np
from collections import Counter, defaultdict
from itertools import groupby
from spacy import displacy

# Entity categories
ENTITIES = [
    "Amount", "Anatomy", "Disease", "Drug",
    "Duration", "Frequency", "Level", "Method",
    "Operation", "Reason", "SideEff", "Symptom",
    "Test", "Test_Value", "Treatment"
]

# Entity-to-color mapping for displacy rendering
COLORS = [
    '#7aecec','#bfeeb7','#feca74','#ff9561',
    '#aa9cfc','#c887fb','#9cc9cc','#ffeb80',
    '#ff8197','#ff8197','#f0d0ff','#bfe1d9',
    '#bfe1d9','#e4e7d2','#e4e7d2','#e4e7d2',
    '#e4e7d2','#e4e7d2'
]
COLOR_MAP = dict(zip([ent.upper() for ent in ENTITIES], COLORS[:len(ENTITIES)]))

# A sentence extracted from a document:
#     text: the sentence text
#     doc_id: the id of the document the sentence belongs to
#     offset: the sentence's offset within the document
#     ents: the list of entities contained in the sentence
class Sentence(object):
    def __init__(self, doc_id, offset, text, ents):
        self.text = text
        self.doc_id = doc_id
        self.offset = offset
        self.ents = ents

    # Represent the sentence by its text
    def __repr__(self):
        return self.text
    # Order sentences by their offset within the document
    def __gt__(self, other):
        return self.offset > other.offset

    # Indexing and slicing; used after prediction to strip the padding added at both ends of a sentence
    def __getitem__(self, key):
        if isinstance(key, int):
            return self.text[key]
        if isinstance(key, slice):
            text = self.text[key]
            start = key.start or 0
            stop = key.stop or len(self.text)
            if start < 0:
                start += len(self.text)
            if stop < 0:
                stop += len(self.text)

            ents = self.ents.find_entities(start, stop).offset(-start)
            offset = self.offset + start
            return Sentence(self.doc_id, offset, text, ents)
    # Render the sentence with entities highlighted in different colors (notebook display)
    def _repr_html_(self):
        ents = []
        for ent in self.ents:
            ents.append({'start': ent.start_pos,
                         'end': ent.end_pos,
                         'label': ent.category})
        ex = {'text': self.text, 'ents': ents, 'title': None, 'settings': {}}
        return displacy.render(ex,
                               style='ent',
                               options={'colors': COLOR_MAP},
                               manual=True,
                               minify=True)

# A single entity: id, category, start position, end position, and the covered text
class Entity(object):
    def __init__(self, ent_id, category, start_pos, end_pos, text):
        self.ent_id = ent_id
        self.category = category
        self.start_pos = start_pos
        self.end_pos = end_pos
        self.text = text

    def __gt__(self, other):
        return self.start_pos > other.start_pos

    def offset(self, offset_val):
        return Entity(self.ent_id,
                      self.category,
                      self.start_pos + offset_val,
                      self.end_pos + offset_val,
                      self.text)

    def __repr__(self):
        return '({}, {}, ({}, {}), {})'.format(self.ent_id,
                                               self.category,
                                               self.start_pos,
                                               self.end_pos,
                                               self.text)

# A sorted collection of entities supporting lookup, offsetting, vectorization and merging (see the usage sketch after this file)
class Entities(object):
    def __init__(self, ents):
        self.ents = sorted(ents)
        self.ent_dict = dict(zip([ent.ent_id for ent in ents], ents))

    def __getitem__(self, key):
        if isinstance(key, int) or isinstance(key, slice):
            return self.ents[key]
        else:
            return self.ent_dict.get(key, None)

    def offset(self, offset_val):
        ents = [ent.offset(offset_val) for ent in self.ents]
        return Entities(ents)

    def vectorize(self, vec_len, cate2idx):
        res_vec = np.zeros(vec_len, dtype=int)
        for ent in self.ents:
            res_vec[ent.start_pos: ent.end_pos] = cate2idx[ent.category]
        return res_vec

    def find_entities(self, start_pos, end_pos):
        res = []
        for ent in self.ents:
            if ent.start_pos > end_pos:
                break
            sp, ep = (max(start_pos, ent.start_pos), min(end_pos, ent.end_pos))
            if ep > sp:
                new_ent = Entity(ent.ent_id, ent.category, sp, ep, ent.text[:(ep - sp)])
                res.append(new_ent)
        return Entities(res)

    def merge(self):
        merged_ents = []
        for ent in self.ents:
            if len(merged_ents) == 0:
                merged_ents.append(ent)
            elif (merged_ents[-1].end_pos == ent.start_pos and
                  merged_ents[-1].category == ent.category):
                merged_ent = Entity(ent_id=merged_ents[-1].ent_id,
                                    category=ent.category,
                                    start_pos=merged_ents[-1].start_pos,
                                    end_pos=ent.end_pos,
                                    text=merged_ents[-1].text + ent.text)
                merged_ents[-1] = merged_ent
            else:
                merged_ents.append(ent)
        return Entities(merged_ents)


class Document(object):
    def __init__(self, doc_id, text, ents=[]):
        self.doc_id = doc_id
        self.text = text
        self.ents = ents
        self.sents = self.extract_sentences()

    def extract_sentences(self):
        offset = 0
        ent_iter = iter(self.ents)
        ent = next(ent_iter, None)
        sents = []
        for text in self.text.split('。'):
            sent_ents = []
            while (ent is not None and
                   ent.start_pos >= offset and
                   ent.end_pos <= offset + len(text)):
                sent_ents.append(ent.offset(-offset))
                ent = next(ent_iter, None)
            sent = Sentence(self.doc_id, offset, text, sent_ents)
            sents.append(sent)
            offset += len(text) + 1
        return sents

    def pad(self, pad_left=0, pad_right=0, pad_val=" "):
        text = pad_left * pad_val + self.text + pad_right * pad_val
        ents = self.ents.offset(pad_left)
        return Document(self.doc_id, text, ents)

    def _repr_html_(self):
        sent = Sentence(self.doc_id, offset=0, text=self.text, ents=self.ents)
        return sent._repr_html_()


class Documents(object):
    def __init__(self, data_dir, doc_ids=None):
        self.data_dir = data_dir
        self.doc_ids = doc_ids
        if self.doc_ids is None:
            self.doc_ids = self.scan_doc_ids()

    def scan_doc_ids(self):
        doc_ids = [fname.split('.')[0] for fname in os.listdir(self.data_dir)]
        return np.unique(doc_ids)

    def read_txt_file(self, doc_id):
        fname = os.path.join(self.data_dir, doc_id + '.txt')
        with open(fname, encoding='utf-8') as f:
            text = f.read()
        return text

    def parse_entity_line(self, raw_str):
        ent_id, label, text = raw_str.strip().split('\t')
        category, pos = label.split(' ', 1)
        pos = pos.split(' ')
        ent = Entity(ent_id, category, int(pos[0]), int(pos[-1]), text)
        return ent

    def read_anno_file(self, doc_id):
        ents = []
        fname = os.path.join(self.data_dir, doc_id + '.ann')
        with open(fname, encoding='utf-8') as f:
            lines = f.readlines()

        for line in lines:
            if line.startswith('T'):
                ent = self.parse_entity_line(line)
                ents.append(ent)
        ents = Entities(ents)

        return ents

    def __len__(self):
        return len(self.doc_ids)

    def get_doc(self, doc_id):
        text = self.read_txt_file(doc_id)
        ents = self.read_anno_file(doc_id)
        doc = Document(doc_id, text, ents)
        return doc

    def __getitem__(self, key):
        if isinstance(key, int):
            doc_id = self.doc_ids[key]
            return self.get_doc(doc_id)
        if isinstance(key, str):
            doc_id = key
            return self.get_doc(doc_id)
        if isinstance(key, np.ndarray) and key.dtype == int:
            doc_ids = self.doc_ids[key]
            return Documents(self.data_dir, doc_ids=doc_ids)


class SentenceExtractor(object):
    # Sentence splitter: sliding window of window_size characters, extended by pad_size on each side (see the usage sketch after this file)
    def __init__(self, window_size=50, pad_size=10):
        self.window_size = window_size
        self.pad_size = pad_size
    # Split a document into windows, recording each window's offset relative to the document so predictions can later be mapped back
    def extract_doc(self, doc):
        num_sents = math.ceil(len(doc.text) / self.window_size)
        doc = doc.pad(pad_left=self.pad_size, pad_right=num_sents * self.window_size - len(doc.text) + self.pad_size)
        sents = []
        for cur_idx in range(self.pad_size, len(doc.text) - self.pad_size, self.window_size):
            sent_text = doc.text[cur_idx - self.pad_size: cur_idx + self.window_size + self.pad_size]
            ents = []
            for ent in doc.ents.find_entities(start_pos=cur_idx - self.pad_size,
                                              end_pos=cur_idx + self.window_size + self.pad_size):
                ents.append(ent.offset(-cur_idx + self.pad_size))
            sent = Sentence(doc.doc_id,
                            offset=cur_idx - 2 * self.pad_size,
                            text=sent_text,
                            ents=Entities(ents))
            sents.append(sent)
        return sents
    # Make the extractor callable on a collection of documents
    def __call__(self, docs):
        sents = []
        for doc in docs:
            sents += self.extract_doc(doc)
        return sents


class Dataset(object):
    def __init__(self, sentences, word2idx=None, cate2idx=None):
        self.sentences = sentences
        self.word2idx = word2idx
        self.cate2idx = cate2idx

    def build_vocab_dict(self, vocab_size=2000):
        counter = Counter()
        for sent in self.sentences:
            for char in sent.text:
                counter[char] += 1
        word2idx = dict()
        word2idx['<unk>'] = 0
        if vocab_size > 0:
            num_most_common = vocab_size - len(word2idx)
        else:
            num_most_common = len(counter)
        for char, _ in counter.most_common(num_most_common):
            word2idx[char] = word2idx.get(char, len(word2idx))
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sent_vec, labels_vec = [], []
        sents = self.sentences[idx]
        if not isinstance(sents, list):
            sents = [sents]
        for sent in sents:
            # Map out-of-vocabulary characters to '<unk>' (index 0)
            content = [self.word2idx.get(c, 0) for c in sent.text]
            sent_vec.append(content)
            labels_vec.append(sent.ents.vectorize(vec_len=len(sent.text), cate2idx=self.cate2idx))
        return np.array(sent_vec), np.expand_dims(np.array(labels_vec), axis=-1)


def make_sentence_prediction(pred, sent, idx2ent):
    ents_vec = np.argmax(pred, axis=1)
    ents = []
    cur_idx = 0
    for label_idx, group in groupby(ents_vec):
        group = list(group)
        start_pos = cur_idx
        end_pos = start_pos + len(group)
        if label_idx > 0:
            text = sent.text[start_pos: end_pos]
            category = idx2ent[label_idx]
            ent = Entity(None, category, start_pos, end_pos, text)
            ents.append(ent)
        cur_idx = end_pos
    return Sentence(sent.doc_id, sent.offset, sent.text, Entities(ents))


def make_doc_prediction(doc_id, sents, docs):
    sents = sorted(sents)
    ents = []
    for sent in sents:
        ents += sent.ents.offset(sent.offset)
    ents = Entities(ents).merge()

    for idx, ent in enumerate(ents):
        ent.ent_id = 'T{}'.format(idx + 1)
    doc = Document(doc_id, docs[doc_id].text, ents=ents)
    return doc

# Convert raw model outputs on the test set into predicted documents
def make_predictions(preds, dataset, sent_pad, docs, idx2ent):
    pred_sents = []
    for sent, pred in zip(dataset.sentences, preds):
        pred_sent = make_sentence_prediction(pred, sent, idx2ent)
        pred_sent = pred_sent[sent_pad: -sent_pad]
        pred_sents.append(pred_sent)

    docs_sents = defaultdict(list)
    for sent in pred_sents:
        docs_sents[sent.doc_id].append(sent)

    pred_docs = dict()
    for doc_id, sentences in docs_sents.items():
        doc = make_doc_prediction(doc_id, sentences, docs)
        pred_docs[doc_id] = doc

    return pred_docs
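
Two of the pieces above are worth seeing in isolation. Entities.vectorize turns entity spans into the per-character label vector that the CRF is trained on, and Entities.merge glues adjacent same-category fragments back together, which is needed because a single entity can be split across neighbouring sentence windows. SentenceExtractor produces windows of window_size + 2 * pad_size characters, which is why the training script builds the model with seq_len = sent_len + 2 * sent_pad. A small illustrative sketch, assuming the classes above are importable from data_utils (the example text is made up; 4 is the index Drug receives in ent2idx):

from data_utils import Entity, Entities, Document, SentenceExtractor

# vectorize / merge on two adjacent fragments of one Drug entity
ents = Entities([Entity('T1', 'Drug', 0, 2, '阿司'),
                 Entity('T2', 'Drug', 2, 4, '匹林')])
print(ents.vectorize(vec_len=6, cate2idx={'Drug': 4}))  # [4 4 4 4 0 0]
print(ents.merge().ents)                                # [(T1, Drug, (0, 4), 阿司匹林)]

# Window geometry of the sentence extractor
doc = Document('doc_0', '糖尿病患者应定期监测血糖。' * 10, ents=Entities([]))
sents = SentenceExtractor(window_size=64, pad_size=10)([doc])
print(len(sents), len(sents[0].text))                   # 3 windows, each 64 + 2 * 10 = 84 characters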
