难过的事情我要反复咀嚼,嚼到它再也不能困扰我半分
—— 25.3.13
一、配置文件 config.py
1.模型与数据路径
model_path:模型训练完成后保存的位置。例如:保存最终的模型权重文件。
schema_path:数据结构定义文件,通常用于描述数据的格式(如字段名、标签类型)。
在NER任务中,可能定义实体类别(如 {"PERSON": "人名", "ORG": "组织"}
)。
train_data_path:训练数据集路径,通常为标注好的文本文件(如 train.txt
或 JSON
格式)。
valid_data_path: 验证数据集路径,用于模型训练时的性能评估和超参数调优。
vocab_path:字符词汇表文件,记录模型中使用的字符集(如中文字符、字母、数字等)。
2.模型架构
max_length:输入文本的最大序列长度。超过此长度的文本会被截断或填充(如用 [PAD]
)。
hidden_size:模型隐藏层神经元的数量,影响模型容量和计算复杂度。
num_layers:模型的堆叠层数(如LSTM、Transformer的编码器/解码器层数)。
class_num:任务类别总数。例如:NER任务中可能有9种实体类型。
3.训练配置
epoch:训练轮数。每轮遍历整个训练数据集一次。
batch_size:每次梯度更新所使用的样本数量。较小的批次可能更适合内存受限的环境。
optimizer:优化器类型,用于调整模型参数。Adam是常用优化器,结合动量梯度下降。
learning_rate:学习率,控制参数更新的步长。值过小可能导致训练缓慢,过大易过拟合。
use_crf:是否启用条件随机场(CRF)层。在序列标注任务(如NER)中,CRF可捕捉标签间的依赖关系,提升准确性。
4.预训练模型
bert_path:预训练BERT模型的路径。BERT是一种强大的预训练语言模型,此处可能用于微调或特征提取。
# -*- coding: utf-8 -*-
"""
配置参数信息
"""
Config = {
"model_path": "model_output",
"schema_path": "ner_data/schema.json",
"train_data_path": "ner_data/train",
"valid_data_path": "ner_data/test",
"vocab_path":"chars.txt",
"max_length": 100,
"hidden_size": 256,
"num_layers": 2,
"epoch": 20,
"batch_size": 16,
"optimizer": "adam",
"learning_rate": 1e-3,
"use_crf": True,
"class_num": 9,
"bert_path": r"F:\人工智能NLP\\NLP资料\week6 语言模型\bert-base-chinese"
}
二、数据加载 loader.py
1.初始化数据加载类
data_path:数据文件存储路径
config:包含训练 / 数据配置的字典
self.config:保存包含训练 / 数据配置的字典
self.path:保存数据文件存储路径
self.vocab:加载字表 / 词表文件存储路径
self.sentences:初始化句子列表
self.schema:加载实体标签与索引的映射关系表
self.load:调用 load()
方法从 data_path
加载原始数据,进行分词、编码、填充/截断等预处理。
def __init__(self, data_path, config):
self.config = config
self.path = data_path
self.vocab = load_vocab(config["vocab_path"])
self.config["vocab_size"] = len(self.vocab)
self.sentences = []
self.schema = self.load_schema(config["schema_path"])
self.load()
2.加载数据并预处理
① 文件读取与分段:按段落分割原始数据。
② 逐行解析:提取字符和标签。
③ 编码转换:将字符转换为词汇表索引序列。
④ 序列标准化:调整序列长度至模型要求。
⑤ 数据存储:保存为张量列表,供训练使用。
self.data:列表,存储预处理后的数据样本,每个样本由输入张量和标签张量组成
sentenece:保存原始文本句子的拼接结果,便于后续可视化或调试。
open():打开文件并返回文件对象,支持读/写/追加等模式。
参数名 | 类型 | 说明 |
---|---|---|
file | 字符串 | 文件路径(绝对/相对路径) |
mode | 字符串 | 打开模式(如 r -只读、w -写入、a -追加) |
encoding | 字符串 | 文件编码(如 utf-8 ,文本模式需指定) |
errors | 字符串 | 编码错误处理方式(如 ignore 、replace ) |
文件对象.read():读取文件内容,返回字符串或字节流
参数名 | 类型 | 说明 |
---|---|---|
size | 整数 | 可选,指定读取的字节数(默认读取全部内容) |
split():按分隔符分割字符串,返回子字符串列表
参数名 | 类型 | 说明 |
---|---|---|
delimiter | 字符串 | 分隔符(默认空格) |
maxsplit | 整数 | 可选,最大分割次数(默认-1表示全部) |
strip():去除字符串首尾指定字符(默认空白字符)
参数名 | 类型 | 说明 |
---|---|---|
chars | 字符串 | 可选,指定需去除的字符集合 |
join():用分隔符连接可迭代对象的元素,返回新字符串
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 需连接的元素集合(如列表、元组) |
sep | 字符串 | 分隔符(默认空字符串) |
列表.append():在列表末尾添加元素
参数名 | 类型 | 说明 |
---|---|---|
obj | 任意类型 | 要添加的元素 |
def load(self):
self.data = []
with open(self.path, encoding="utf8") as f:
segments = f.read().split("\n\n")
for segment in segments:
sentenece = []
labels = []
for line in segment.split("\n"):
if line.strip() == "":
continue
char, label = line.split()
sentenece.append(char)
labels.append(self.schema[label])
self.sentences.append("".join(sentenece))
input_ids = self.encode_sentence(sentenece)
labels = self.padding(labels, -1)
self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
return
3.加载字 / 词表
load_vocab 函数用于从指定路径加载词汇表文件,并将每个词汇项映射到一个从 1 开始的唯一整数索引(索引 0 保留给 Padding 占位符)
token_dict:字典,存储词汇到索引的映射
open():打开文件并返回文件对象,用于读写文件内容
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
file_name | str | 无 | 文件路径(需包含扩展名) |
mode | str | 'r' | 文件打开模式: - 'r' : 只读- 'w' : 只写(覆盖原文件)- 'a' : 追加写入- 'b' : 二进制模式- 'x' : 创建新文件(若存在则报错) |
buffering | int | None | 缓冲区大小(仅二进制模式有效) |
encoding | str | None | 文件编码(仅文本模式有效,如 'utf-8' ) |
newline | str | '\n' | 行结束符(仅文本模式有效) |
closefd | bool | True | 是否在文件关闭时自动关闭文件描述符 |
dir_fd | int | -1 | 文件描述符(高级用法,通常忽略) |
flags | int | 0 | Linux 系统下的额外标志位 |
mode | str | 无 | (重复参数,实际使用中只需指定 mode ) |
enumerate():遍历可迭代对象时,同时返回元素的索引和值。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
iterable | 可迭代对象 | 无 | 需要遍历的对象(如列表、元组、字符串等) |
start | int | 0 | 索引的起始值(可自定义,如从 1 开始) |
strip():移除字符串开头和结尾的空白字符或指定字符
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
chars | str | None | 需要移除的字符集合(默认为空格、换行、制表符 \t 、换页符 \f 、回车 \r ) |
#加载字表或词表
def load_vocab(vocab_path):
token_dict = {}
with open(vocab_path, encoding="utf8") as f:
for index, line in enumerate(f):
token = line.strip()
token_dict[token] = index + 1 #0留给padding位置,所以从1开始
return token_dict
4.加载映射关系表
加载位于指定路径的 JSON 格式的模式文件,并将其内容解析为 Python 对象以便在数据生成过程中使用。
open():打开文件并返回文件对象,用于读写文件内容。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
file_name | str | 无 | 文件路径(需包含扩展名) |
mode | str | 'r' | 文件打开模式: - 'r' : 只读- 'w' : 只写(覆盖原文件)- 'a' : 追加写入- 'b' : 二进制模式- 'x' : 创建新文件(若存在则报错) |
buffering | int | None | 缓冲区大小(仅二进制模式有效) |
encoding | str | None | 文件编码(仅文本模式有效,如 'utf-8' ) |
newline | str | '\n' | 行结束符(仅文本模式有效) |
closefd | bool | True | 是否在文件关闭时自动关闭文件描述符 |
dir_fd | int | -1 | 文件描述符(高级用法,通常忽略) |
flags | int | 0 | Linux 系统下的额外标志位 |
mode | str | 无 | (重复参数,实际使用中只需指定 mode ) |
json.load():从已打开的 JSON 文件对象中加载数据,并将其转换为 Python 对象(如字典、列表)。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
fp | io.TextIO | 无 | 已打开的文件对象(需处于读取模式) |
indent | int/str | None | 缩进空格数(美化输出,如 4 或 " " ) |
sort_keys | bool | False | 是否对 JSON 键进行排序 |
load_hook | callable | None | 自定义对象加载回调函数 |
object_hook | callable | None | 自定义对象解析回调函数 |
def load_schema(self, path):
with open(path, encoding="utf8") as f:
return json.load(f)
5.封装数据
DataLoader():PyTorch 模型训练的标配工具,通过合理的参数配置(如 batch_size
、num_workers
、shuffle
),可以显著提升数据加载效率,尤其适用于大规模数据集和复杂预处理任务。其与 Dataset
类的配合使用,是构建高效训练管道的核心。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
dataset | Dataset | None | 必须参数,自定义数据集对象(需继承 torch.utils.data.Dataset )。 |
batch_size | int | 1 | 每个批次的样本数量。 |
shuffle | bool | False | 是否在每个 epoch 开始时打乱数据顺序(训练时推荐设为 True )。 |
num_workers | int | 0 | 使用多线程加载数据的工人数量(需大于 0 时生效)。 |
pin_memory | bool | False | 是否将数据存储在 pinned memory 中(加速 GPU 数据传输)。 |
drop_last | bool | False | 如果数据集长度无法被 batch_size 整除,是否丢弃最后一个不完整的批次。 |
persistent_workers | bool | False | 是否保持工作线程在 epoch 之间持续运行(减少多线程初始化开销)。 |
worker_init_fn | callable | None | 自定义工作线程初始化函数。 |
#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
dg = DataGenerator(data_path, config)
dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
return dl
6.对于输入文本做截断 / 填充
#补齐或截断输入的序列,使其可以在一个batch内运算
def padding(self, input_id, pad_token=0):
input_id = input_id[:self.config["max_length"]]
input_id += [pad_token] * (self.config["max_length"] - len(input_id))
return input_id
7.对于输入的文本编码
输入:原始文本
text
(字符串),padding
标志(布尔值,决定是否填充)初始化:创建空列表
input_id
存储编码后的索引序列分词/字符处理分支:
- 词级别处理(
"words.txt"
):
使用结巴分词(jieba.cut
)将文本切分为词语,遍历每个词语,查询词汇表self.vocab
:
- 若词语存在 → 添加其索引
- 若不存在 → 使用
[UNK]
(未知词)的索引- 字符级别处理(其他情况):
直接遍历文本的每个字符,查询词汇表self.vocab
:
- 若字符存在 → 添加其索引
- 若不存在 → 使用
[UNK]
的索引条件执行:若
padding=True
,调用self.padding
方法对input_id
进行填充返回:整数列表
input_id
,表示文本的编码序列
input_id:初始化列表,存储词 / 字符的索引
jieba.cut():将中文句子分割成词语,支持三种分词模式(精确模式、全模式、搜索引擎模式)
参数名 | 类型 | 说明 |
---|---|---|
sentence | 字符串 | 需要分词的中文句子 |
cut_all | 布尔值 | 是否采用全模式(True为全模式,False为精确模式,默认False) |
HMM | 布尔值 | 是否使用隐马尔可夫模型(True为使用,默认True) |
列表.append():在列表末尾添加一个元素,修改原列表
参数名 | 类型 | 说明 |
---|---|---|
obj | 任意类型 | 要添加的元素(支持字符串、数字、列表等) |
字典.get():安全获取字典中指定键的值,键不存在时返回默认值(默认为None
)
参数名 | 类型 | 说明 |
---|---|---|
key | 不可变类型 | 要查询的键 |
default | 任意类型 | 可选,键不存在时返回的默认值(若未指定则返回None) |
def encode_sentence(self, text, padding=True):
input_id = []
if self.config["vocab_path"] == "words.txt":
for word in jieba.cut(text):
input_id.append(self.vocab.get(word, self.vocab["[UNK]"]))
else:
for char in text:
input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))
if padding:
input_id = self.padding(input_id)
return input_id
完整代码
# -*- coding: utf-8 -*-
import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
"""
数据加载
"""
class DataGenerator:
def __init__(self, data_path, config):
self.config = config
self.path = data_path
self.vocab = load_vocab(config["vocab_path"])
self.config["vocab_size"] = len(self.vocab)
self.sentences = []
self.schema = self.load_schema(config["schema_path"])
self.load()
def load(self):
self.data = []
with open(self.path, encoding="utf8") as f:
segments = f.read().split("\n\n")
for segment in segments:
sentenece = []
labels = []
for line in segment.split("\n"):
if line.strip() == "":
continue
char, label = line.split()
sentenece.append(char)
labels.append(self.schema[label])
self.sentences.append("".join(sentenece))
input_ids = self.encode_sentence(sentenece)
labels = self.padding(labels, -1)
self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])
return
def encode_sentence(self, text, padding=True):
input_id = []
if self.config["vocab_path"] == "words.txt":
for word in jieba.cut(text):
input_id.append(self.vocab.get(word, self.vocab["[UNK]"]))
else:
for char in text:
input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))
if padding:
input_id = self.padding(input_id)
return input_id
#补齐或截断输入的序列,使其可以在一个batch内运算
def padding(self, input_id, pad_token=0):
input_id = input_id[:self.config["max_length"]]
input_id += [pad_token] * (self.config["max_length"] - len(input_id))
return input_id
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
def load_schema(self, path):
with open(path, encoding="utf8") as f:
return json.load(f)
#加载字表或词表
def load_vocab(vocab_path):
token_dict = {}
with open(vocab_path, encoding="utf8") as f:
for index, line in enumerate(f):
token = line.strip()
token_dict[token] = index + 1 #0留给padding位置,所以从1开始
return token_dict
#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
dg = DataGenerator(data_path, config)
dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
return dl
if __name__ == "__main__":
from config import Config
dg = DataGenerator("../ner_data/train.txt", Config)
三、模型建立 model.py
1.模型初始化
hidden_size:定义LSTM隐藏层的维度(即每个时间步输出的特征数量)
vocab_size:词表大小,即嵌入层(Embedding)可处理的词汇总数
max_length:输入序列的最大长度,用于数据预处理(如截断或填充)
class_num:分类任务的类别数量,决定线性层(nn.Linear
)的输出维度
num_layers:堆叠的LSTM层数,用于增加模型复杂度
nn.Embedding():将离散的索引映射为稠密向量(如词嵌入)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
num_embeddings | 整数 | 无 | 词表大小(如 vocab_size + 1 ) |
embedding_dim | 整数 | 无 | 嵌入向量维度(如 hidden_size ) |
padding_idx | 整数 | None | 指定填充符索引(如 0 ),该位置的梯度不更新 |
nn.LSTM():长短期记忆网络(LSTM),用于序列建模。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
input_size | 整数 | 无 | 输入特征维度(如嵌入层输出维度 hidden_size ) |
hidden_size | 整数 | 无 | 隐藏状态维度(决定模型容量) |
num_layers | 整数 | 1 | LSTM 堆叠层数(多层时上一层的输出作为下一层的输入) |
batch_first | 布尔值 | False | 输入张量是否为 (batch_size, seq_len, input_size) 格式 |
bidirectional | 布尔值 | False | 是否启用双向 LSTM(输出维度变为 hidden_size * 2 ) |
nn.Linear():实现全连接层的线性变换(y = xW^T + b
)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
in_features | 整数 | 无 | 输入特征维度(如词向量维度 hidden_size ) |
out_features | 整数 | 无 | 输出特征维度(如分类类别数 class_num ) |
bias | 布尔值 | True | 是否启用偏置项 |
CRF():条件随机场层,用于序列标注任务中约束标签转移逻辑。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
num_tags | 整数 | 无 | 标签类别数(如 class_num ) |
batch_first | 布尔值 | False | 输入张量是否为 (batch_size, seq_len) 格式 |
torch.nn.CrossEntropyLoss():计算交叉熵损失,常用于分类任务。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
ignore_index | 整数 | -1 | 忽略指定索引的标签(如填充符 -1 ) |
reduction | 字符串 | mean | 损失聚合方式(可选 none 、sum 、mean ) |
def __init__(self, config):
super(TorchModel, self).__init__()
hidden_size = config["hidden_size"]
vocab_size = config["vocab_size"] + 1
max_length = config["max_length"]
class_num = config["class_num"]
num_layers = config["num_layers"]
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
self.classify = nn.Linear(hidden_size * 2, class_num)
self.crf_layer = CRF(class_num, batch_first=True)
self.use_crf = config["use_crf"]
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1) #loss采用交叉熵损失
2.前向计算
输入 x → 嵌入层 → 序列层 → 分类层 → 预测值 → 分支判断:
│
├── 存在 target → CRF? → 是:计算 CRF 损失(带掩码)
│ │
│ └→ 否:计算交叉熵损失(展平处理)
│
└── 无 target → CRF? → 是:解码最优标签序列
│
└→ 否:直接返回预测 logits
gt():张量的逐元素比较函数,返回布尔型张量,标记输入张量中大于指定值的元素位置。常用于生成掩码(如忽略填充符)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
other | Tensor/标量 | 无 | 比较的阈值或张量。若为标量,则张量中每个元素与该值比较;若为张量,需与输入张量形状相同。 |
out | Tensor | None | 可选输出张量,用于存储结果。 |
shape():返回张量的维度信息,描述各轴的大小。
view():调整张量的形状,支持自动推断维度(通过-1
占位符)。常用于数据展平或维度转换。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
*shape | 可变参数 | 无 | 目标形状的维度序列,如view(2, 3) 或view(-1, 28) ,-1 表示自动计算。 |
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, target=None):
x = self.embedding(x) #input shape:(batch_size, sen_len)
x, _ = self.layer(x) #input shape:(batch_size, sen_len, input_dim)
predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)
if target is not None:
if self.use_crf:
mask = target.gt(-1)
return - self.crf_layer(predict, target, mask, reduction="mean")
else:
#(number, class_num), (number)
return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
else:
if self.use_crf:
return self.crf_layer.decode(predict)
else:
return predict
3.选择优化器
Adam():自适应矩估计优化器(Adaptive Moment Estimation),结合动量和 RMSProp 的优点。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
lr | float | 1e-3 | 学习率。 |
betas | tuple | (0.9, 0.999) | 动量系数(β₁, β₂)。 |
eps | float | 1e-8 | 防止除零误差。 |
weight_decay | float | 0 | 权重衰减率。 |
amsgrad | bool | False | 是否启用 AMSGrad 优化。 |
foreach | bool | False | 是否为每个参数单独计算梯度。 |
SGD():随机梯度下降优化器(Stochastic Gradient Descent)
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
lr | float | 1e-3 | 学习率。 |
momentum | float | 0 | 动量系数(如 momentum=0.9 )。 |
weight_decay | float | 0 | 权重衰减率。 |
dampening | float | 0 | 动力衰减系数(用于 SGD with Momentum)。 |
nesterov | bool | False | 是否启用 Nesterov 动量。 |
foreach | bool | False | 是否为每个参数单独计算梯度。 |
parameters():返回模型所有可训练参数的迭代器,常用于参数初始化或梯度清零。
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
filter | callable | None | 过滤条件函数(如 lambda p: p.requires_grad )。默认返回所有参数。 |
def choose_optimizer(config, model):
optimizer = config["optimizer"]
learning_rate = config["learning_rate"]
if optimizer == "adam":
return Adam(model.parameters(), lr=learning_rate)
elif optimizer == "sgd":
return SGD(model.parameters(), lr=learning_rate)
4.模型建立
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
"""
建立网络模型结构
"""
class TorchModel(nn.Module):
def __init__(self, config):
super(TorchModel, self).__init__()
hidden_size = config["hidden_size"]
vocab_size = config["vocab_size"] + 1
max_length = config["max_length"]
class_num = config["class_num"]
num_layers = config["num_layers"]
self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)
self.classify = nn.Linear(hidden_size * 2, class_num)
self.crf_layer = CRF(class_num, batch_first=True)
self.use_crf = config["use_crf"]
self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1) #loss采用交叉熵损失
#当输入真实标签,返回loss值;无真实标签,返回预测值
def forward(self, x, target=None):
x = self.embedding(x) #input shape:(batch_size, sen_len)
x, _ = self.layer(x) #input shape:(batch_size, sen_len, input_dim)
predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)
if target is not None:
if self.use_crf:
mask = target.gt(-1)
return - self.crf_layer(predict, target, mask, reduction="mean")
else:
#(number, class_num), (number)
return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))
else:
if self.use_crf:
return self.crf_layer.decode(predict)
else:
return predict
def choose_optimizer(config, model):
optimizer = config["optimizer"]
learning_rate = config["learning_rate"]
if optimizer == "adam":
return Adam(model.parameters(), lr=learning_rate)
elif optimizer == "sgd":
return SGD(model.parameters(), lr=learning_rate)
if __name__ == "__main__":
from config import Config
model = TorchModel(Config)
四、模型效果评估 evaluate.py
1.模型流程
Ⅰ、数据准备与初始化
加载验证集:从指定路径加载预处理后的验证数据,保持数据顺序以避免随机性干扰评估结果
初始化统计字典:为每个实体类别(如LOCATION、PERSON等)创建计数器,记录“正确识别数”“样本实体数”等统计指标
Ⅱ、模型推理与预测
切换评估模式:调用
model.eval()
关闭Dropout等训练层,确保推理稳定性批次处理:
数据迁移至GPU:若CUDA可用,将输入ID和标签移至GPU加速计算
无梯度预测:在
torch.no_grad()
上下文中执行模型推理,减少内存占用输出处理:若未使用CRF层,通过
torch.argmax
直接获取预测标签序列;若使用CRF,需解码最优路径
Ⅲ、实体解码与对齐
标签序列转换:将数值标签拼接为字符串(如
[0,4,4]
→"044"
),并截取与句子长度对齐正则匹配实体:
规则定义:通过正则表达式匹配标签模式(如
04+
表示LOCATION实体),提取连续B-I标签对应的文本片段索引对齐:根据匹配的起止位置从原始句子中截取实体(例如
"04+"
匹配到索引3-5,则提取句子[3:5]
)
Ⅳ、统计与评估指标计算
对比真实与预测实体:遍历每个句子的实体列表,统计以下指标:
正确识别数:预测实体存在于真实列表中的数量。
样本实体数:真实实体总数。
识别出实体数:预测实体总数
计算指标:
精确率(Precision):正确识别数 / 识别出实体数。
召回率(Recall):正确识别数 / 样本实体数。
F1值:精确率与召回率的调和平均
输出结果:按实体类别输出指标,并计算宏平均(Macro-F1)和微平均(Micro-F1)
Ⅴ、关键设计细节
标签编码规则:采用BIO格式(如B-LOCATION=0,I-LOCATION=4),确保实体连续性
异常处理:添加
1e-5
平滑项避免除零错误,增强数值稳定性性能优化:禁用梯度计算、GPU加速、批次处理提升效率
2.初始化
Ⅰ、加载配置文件、模型及日志模块 ——>
Ⅱ、读取验证集数据(固定顺序,避免随机性干扰评估)——>
Ⅲ、初始化统计字典
stats_dict
,按实体类别记录正确识别数、样本实体数等
config:存储运行时配置,例如数据路径、超参数(如批次大小 batch_size
)、是否使用CRF层等。通过 config["valid_data_path"]
动态获取验证集路径。
model:待评估的模型实例,用于调用预测方法(如 model(input_id)
),需提前完成训练和加载。
logger:记录运行日志,例如输出评估指标(准确率、F1值)到文件或控制台,便于调试和监控。
valid_data:验证数据集,用于模型训练时的性能评估和超参数调优。
load_data():数据加载类中,用torch自带的DataLoader类封装数据的函数
def __init__(self, config, model, logger):
self.config = config
self.model = model
self.logger = logger
self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
3.统计模型效果
Ⅰ、解码实体 ——> Ⅱ、对比结果
len():返回对象的元素数量(字符串、列表、元组、字典等)
参数名 | 类型 | 说明 |
---|---|---|
object | 任意可迭代对象 | 如字符串、列表、字典等 |
torch.argmax():返回张量中最大值所在的索引
参数名 | 类型 | 说明 |
---|---|---|
input | Tensor | 输入张量 |
dim | int | 沿指定维度查找最大值 |
keepdim | bool | 是否保持输出维度一致 |
cpu():将张量从GPU移动到CPU内存
zip():将多个可迭代对象打包成元组列表
参数名 | 类型 | 说明 |
---|---|---|
iterables | 多个可迭代对象 | 如列表、元组、字符串 |
.detach():从计算图中分离张量,阻止梯度传播
.tolist():将张量或数组转换为Python列表
def write_stats(self, labels, pred_results, sentences):
assert len(labels) == len(pred_results) == len(sentences)
if not self.config["use_crf"]:
pred_results = torch.argmax(pred_results, dim=-1)
for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
if not self.config["use_crf"]:
pred_label = pred_label.cpu().detach().tolist()
true_label = true_label.cpu().detach().tolist()
true_entities = self.decode(sentence, true_label)
pred_entities = self.decode(sentence, pred_label)
# print("=+++++++++")
# print(true_entities)
# print(pred_entities)
# print('=+++++++++')
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
self.stats_dict[key]["样本实体数"] += len(true_entities[key])
self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
return
4.可视化统计模型效果
精确率 (Precision):正确预测实体数 / 总预测实体数
召回率 (Recall):正确预测实体数 / 总真实实体数
F1值:精确率与召回率的调和平均
F1:F1分数:准确率与召回率的调和平均数,综合衡量模型的精确性与覆盖能力。
F1_scores:存储四个实体类别的 F1 分数,用于计算宏观平均。
precision:准确率:模型预测为某类实体的结果中,正确的比例。反映模型预测的精确度。
recall:召回率:真实存在的某类实体中,被模型正确识别的比例。反映模型对实体的覆盖能力。
key:当前处理的实体类别(如 "PERSON"
、"LOCATION"
)。
correct_pred:总正确识别数:所有类别中被正确识别的实体总数。
total_pred:总识别实体数:模型预测出的所有实体数量(含错误识别)。
true_enti:总样本实体数:验证数据中真实存在的所有实体数量。
micro_precision:微观准确率:全局视角下的准确率,所有实体类别的正确识别数与总识别数的比例。
micro_recall:微观召回率:全局视角下的召回率,所有实体类别的正确识别数与总样本实体数的比例。
micro_f1:微观F1分数:微观准确率与微观召回率的调和平均数。
列表.append():在列表末尾添加元素
参数名 | 类型 | 说明 |
---|---|---|
element | 任意 | 要添加的元素 |
logger.info():记录日志信息(需配置日志模块)
参数名 | 类型 | 说明 |
---|---|---|
format | str | 格式化字符串 |
*args | 可变参数 | 格式化参数 |
sum():计算可迭代对象的元素总和
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 如列表、元组 |
start | 数值(可选) | 初始累加值 |
列表推导式:通过简洁语法生成新列表,语法:[表达式 for item in iterable if 条件]
def show_stats(self):
F1_scores = []
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
F1 = (2 * precision * recall) / (precision + recall + 1e-5)
F1_scores.append(F1)
self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
micro_precision = correct_pred / (total_pred + 1e-5)
micro_recall = correct_pred / (true_enti + 1e-5)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
self.logger.info("Micro-F1 %f" % micro_f1)
self.logger.info("--------------------")
return
5.评估模型效果
模型切换为评估模式:关闭Dropout等训练层
批次处理数据:
提取原始句子
sentences
将数据迁移至GPU(若可用)
预测时禁用梯度计算(
torch.no_grad()
)优化内存统计结果:调用
write_stats
对比预测与真实标签
epoch:当前训练轮次,用于日志。
logger:记录日志的工具。
stats_dict:统计字典,记录各实体类别的指标。
valid_data:验证数据集,通常由 load_data
加载(如 config["valid_data_path"]
指定路径)
index:循环中的批次索引
batch_data:循环中的数据。
sentences:当前批次的原始句子
pred_results:模型预测结果
write_stats():写入统计信息
show_stats():显示统计结果
logger.info():记录日志信息(需配置日志模块)
参数名 | 类型 | 说明 |
---|---|---|
format | str | 格式化字符串 |
*args | 可变参数 | 格式化参数 |
defaultdict():创建带有默认值工厂的字典
参数名 | 类型 | 说明 |
---|---|---|
default_factory | 可调用对象 | 如int、list、自定义函数 |
model.eval():将模型设置为评估模式(关闭Dropout等训练层)
enumerate():返回索引和元素组成的枚举对象
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 如列表、字符串 |
start | int(可选) | 起始索引,默认为0 |
torch.cuda.is_available():检查当前环境是否支持CUDA(GPU加速)
cuda():将张量或模型移动到GPU
参数名 | 类型 | 说明 |
---|---|---|
device | int/str | 指定GPU设备号,如"cuda:0" |
torch.no_grad():禁用梯度计算,节省内存并加速推理
def eval(self, epoch):
self.logger.info("开始测试第%d轮模型效果:" % epoch)
self.stats_dict = {"LOCATION": defaultdict(int),
"TIME": defaultdict(int),
"PERSON": defaultdict(int),
"ORGANIZATION": defaultdict(int)}
self.model.eval()
for index, batch_data in enumerate(self.valid_data):
sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
if torch.cuda.is_available():
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
with torch.no_grad():
pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
self.write_stats(labels, pred_results, sentences)
self.show_stats()
return
6.解码
标签序列预处理:将数值标签拼接为字符串(如
[0,4,4]
→"044"
)正则匹配实体:
04+
:B-LOCATION(0)后接多个I-LOCATION(4)
15+
:B-ORGANIZATION(1)后接I-ORGANIZATION(5)其他实体类别同理
索引对齐:根据匹配位置截取原始句子中的实体文本
Ⅰ、输入预处理
在原句首添加 $
符号,通常用于对齐标签与字符位置(例如避免索引越界)
sentence = "$" + sentence
Ⅱ、标签序列转换
将整数标签序列转换为字符串,并截取长度与 sentence
对齐
str.join():将可迭代对象中的字符串元素按指定分隔符连接成一个新字符串
参数名 | 类型 | 说明 |
---|---|---|
iterable | 可迭代对象 | 元素必须为字符串类型 |
str():将对象转换为字符串表示形式,支持自定义类的 __str__
方法
参数名 | 类型 | 说明 |
---|---|---|
object | 任意 | 要转换的对象 |
len():返回对象的长度或元素个数(适用于字符串、列表、字典等)
参数名 | 类型 | 说明 |
---|---|---|
object | 可迭代对象 | 如字符串、列表等 |
列表推导式:通过简洁语法生成新列表,支持条件过滤和多层循环
[expression for item in iterable if condition]
部分 | 类型 | 说明 |
---|---|---|
expression | 表达式 | 对 item 处理后的结果 |
item | 变量 | 迭代变量 |
iterable | 可迭代对象 | 如列表、range() 生成的序列 |
condition | 条件表达式 (可选) | 过滤不符合条件的元素 |
labels = "".join([str(x) for x in labels[:len(sentence)+1]])
Ⅲ、 初始化结果容器
创建默认值为列表的字典,存储四类实体(LOCATION、ORGANIZATION、PERSON、TIME)的识别结果
defaultdict():创建默认值字典,当键不存在时自动生成默认值(基于工厂函数)
参数名 | 类型 | 说明 |
---|---|---|
default_factory | 可调用对象 | 如 int 、list 或自定义函数 |
results = defaultdict(list)
Ⅳ、 正则表达式匹配
(04+)
: 匹配以0
(B-LOCATION)开头,后接多个4
(I-LOCATION)的连续标签
(15+)
、(26+)
、(37+)
:分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。
re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match
对象
参数名 | 类型 | 说明 |
---|---|---|
pattern | str 或正则表达式对象 | 要匹配的正则表达式模式 |
string | str | 要搜索的字符串 |
flags | int (可选) | 正则匹配标志(如 re.IGNORECASE ) |
.span():返回正则匹配的起始和结束索引(左闭右开区间)
列表.append():向列表末尾添加单个元素,直接修改原列表
参数名 | 类型 | 说明 |
---|---|---|
element | 任意 | 要添加的元素 |
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
Ⅴ、完整代码
'''
{
"B-LOCATION": 0,
"B-ORGANIZATION": 1,
"B-PERSON": 2,
"B-TIME": 3,
"I-LOCATION": 4,
"I-ORGANIZATION": 5,
"I-PERSON": 6,
"I-TIME": 7,
"O": 8
}
'''
def decode(self, sentence, labels):
labels = "".join([str(x) for x in labels[:len(sentence)]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
results["TIME"].append(sentence[s:e])
return results
7.完整代码
# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data
"""
模型效果测试
"""
class Evaluator:
def __init__(self, config, model, logger):
self.config = config
self.model = model
self.logger = logger
self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
def eval(self, epoch):
self.logger.info("开始测试第%d轮模型效果:" % epoch)
self.stats_dict = {"LOCATION": defaultdict(int),
"TIME": defaultdict(int),
"PERSON": defaultdict(int),
"ORGANIZATION": defaultdict(int)}
self.model.eval()
for index, batch_data in enumerate(self.valid_data):
sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]
if torch.cuda.is_available():
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
with torch.no_grad():
pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测
self.write_stats(labels, pred_results, sentences)
self.show_stats()
return
def write_stats(self, labels, pred_results, sentences):
assert len(labels) == len(pred_results) == len(sentences)
if not self.config["use_crf"]:
pred_results = torch.argmax(pred_results, dim=-1)
for true_label, pred_label, sentence in zip(labels, pred_results, sentences):
if not self.config["use_crf"]:
pred_label = pred_label.cpu().detach().tolist()
true_label = true_label.cpu().detach().tolist()
true_entities = self.decode(sentence, true_label)
pred_entities = self.decode(sentence, pred_label)
# print("=+++++++++")
# print(true_entities)
# print(pred_entities)
# print('=+++++++++')
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])
self.stats_dict[key]["样本实体数"] += len(true_entities[key])
self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])
return
def show_stats(self):
F1_scores = []
for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:
# 正确率 = 识别出的正确实体数 / 识别出的实体数
# 召回率 = 识别出的正确实体数 / 样本的实体数
precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])
recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])
F1 = (2 * precision * recall) / (precision + recall + 1e-5)
F1_scores.append(F1)
self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))
self.logger.info("Macro-F1: %f" % np.mean(F1_scores))
correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])
micro_precision = correct_pred / (total_pred + 1e-5)
micro_recall = correct_pred / (true_enti + 1e-5)
micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)
self.logger.info("Micro-F1 %f" % micro_f1)
self.logger.info("--------------------")
return
'''
{
"B-LOCATION": 0,
"B-ORGANIZATION": 1,
"B-PERSON": 2,
"B-TIME": 3,
"I-LOCATION": 4,
"I-ORGANIZATION": 5,
"I-PERSON": 6,
"I-TIME": 7,
"O": 8
}
'''
def decode(self, sentence, labels):
labels = "".join([str(x) for x in labels[:len(sentence)]])
results = defaultdict(list)
for location in re.finditer("(04+)", labels):
s, e = location.span()
results["LOCATION"].append(sentence[s:e])
for location in re.finditer("(15+)", labels):
s, e = location.span()
results["ORGANIZATION"].append(sentence[s:e])
for location in re.finditer("(26+)", labels):
s, e = location.span()
results["PERSON"].append(sentence[s:e])
for location in re.finditer("(37+)", labels):
s, e = location.span()
results["TIME"].append(sentence[s:e])
return results
五、主函数文件 main.py
① 环境初始化与配置加载 ——>
② 数据加载与预处理 ——>
③ 模型初始化与硬件适配 ——>
④ 优化器与评估器初始化 ——>
⑤ 训练循环与参数更新 ——>
⑥ 模型评估与权重保存
1.导入文件
# -*- coding: utf-8 -*-
import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
2.日志配置
logging.basicConfig():配置日志系统的基础参数(一次性设置,应在首次日志调用前调用)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
filename | 字符串 | 否 | None | 日志输出文件名(若指定,日志写入文件而非控制台) |
filemode | 字符串 | 否 | 'a' | 文件打开模式(如'w' 覆盖,'a' 追加) |
format | 字符串 | 否 | 基础格式 | 日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s' ) |
datefmt | 字符串 | 否 | 无 | 时间格式(如'%Y-%m-%d %H:%M:%S' ) |
level | 整数 | 否 | WARNING | 日志级别(如logging.INFO 、logging.DEBUG ) |
stream | 对象 | 否 | None | 指定日志输出流(如sys.stderr ,与filename 互斥) |
logging.getLogger():获取或创建指定名称的日志记录器(Logger)。若name
为None
,返回根日志记录器
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
name | 字符串 | 否 | None | 日志记录器名称(分层结构,如'module.sub' ) |
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
3.主函数 main
Ⅰ、创建模型保存目录
os.path.isdir():检查指定路径是否为目录(文件夹)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
path | 字符串 | 是 | 无 | 要检查的路径(绝对或相对) |
os.mkdir():创建单个目录(若父目录不存在会抛出异常)
参数名 | 类型 | 是否必需 | 默认值 | 说明 |
---|---|---|---|---|
path | 字符串 | 是 | 无 | 要创建的目录路径 |
mode | 整数 | 否 | 0o777 | 目录权限(八进制格式,某些系统可能忽略此参数) |
#创建保存模型的目录
if not os.path.isdir(config["model_path"]):
os.mkdir(config["model_path"])
Ⅱ、加载训练数据
#加载训练数据
train_data = load_data(config["train_data_path"], config)
Ⅲ、加载模型
#加载模型
model = TorchModel(config)
Ⅳ、检查GPU并迁移模型
torch.cuda.is_available():检查系统是否满足 CUDA 环境要求
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
cuda():将张量或模型移动到GPU显存,加速计算
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
device | int/str | 否 | 指定GPU设备(如0 或"cuda:0" ) | tensor.cuda(device=0) |
non_blocking | bool | 否 | 是否异步传输数据(默认False) | tensor.cuda(non_blocking=True) |
# 标识是否使用gpu
cuda_flag = torch.cuda.is_available()
if cuda_flag:
logger.info("gpu可以使用,迁移模型至gpu")
model = model.cuda()
Ⅴ、加载优化器
#加载优化器
optimizer = choose_optimizer(config, model)
Ⅵ、加载评估器
#加载效果测试类
evaluator = Evaluator(config, model, logger)
Ⅶ、模型训练主流程
① Epoch循环控制
range():Python 内置函数,用于生成一个不可变的整数序列,核心功能是为循环控制提供高效的数值迭代支持
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
start | 整数 | 0 | 序列起始值(包含)。若省略,则默认从 0 开始。例如 range(3) 等价于 range(0,3) 。 |
stop | 整数 | 必填 | 序列结束值(不包含)。例如 range(2, 5) 生成 2,3,4 |
step | 整数 | 1 | 步长(正/负): - 正步长需满足 start < stop ,否则无输出(如 range(5, 2) 无效)。- 负步长需满足 start > stop ,例如 range(5, 0, -1) 生成 5,4,3,2,1 **不能为 0 **(否则触发 ValueError ) |
for epoch in range(config["epoch"]):
epoch += 1
② 模型设置训练模式
train_loss:计算当前批次的损失值,通常结合损失函数(如交叉熵、均方误差)使用
model.train():设置模型为训练模式,启用Dropout、BatchNorm等层的训练行为
参数 | 类型 | 默认值 | 说明 | 示例 |
---|---|---|---|---|
mode | bool | True | 是否启用训练模式(True)或评估模式(False) | model.train(True) |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
③ Batch数据遍历
enumerate():遍历可迭代对象时返回索引和元素,支持自定义起始索引
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
iterable | Iterable | 是 | 可迭代对象(如列表、生成器) | enumerate(["a", "b"]) |
start | int | 否 | 索引起始值(默认0) | enumerate(data, start=1) |
for index, batch_data in enumerate(train_data):
④ 梯度清零与设备切换
optimizer.zero_grad():清空模型参数的梯度,防止梯度累积
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
set_to_none | bool | 否 | 是否将梯度置为None (高效但危险) | optimizer.zero_grad(True) |
cuda():将张量或模型移动到GPU显存,加速计算
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
device | int/str | 否 | 指定GPU设备(如0 或"cuda:0" ) | tensor.cuda(device=0) |
non_blocking | bool | 否 | 是否异步传输数据(默认False) | tensor.cuda(non_blocking=True) |
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
⑤ 前向传播与损失计算
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
⑥ 反向传播与参数更新
loss.backward():反向传播计算梯度,基于损失值更新模型参数的.grad
属性
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
retain_graph | bool | 否 | 是否保留计算图(用于多次反向传播) | loss.backward(retain_graph=True) |
optimizer.step():根据梯度更新模型参数,执行优化算法(如SGD、Adam)
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
closure | Callable | 否 | 重新计算损失的闭包函数(如LBFGS) | optimizer.step(closure) |
loss.backward()
optimizer.step()
⑦ 损失记录与日志输出
列表.append():在列表末尾添加元素,直接修改原列表
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
object | Any | 是 | 要添加到列表末尾的元素 | train_loss.append(loss.item()) |
int():将字符串或浮点数转换为整数,支持进制转换
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
x | str/float | 是 | 待转换的值(如字符串或浮点数) | int("10", base=2) (输出2进制10=2) |
base | int | 否 | 进制(默认10) |
len():返回对象(如列表、字符串)的长度或元素个数
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
obj | Sequence/Collection | 是 | 可计算长度的对象(如列表、字符串) | len([1, 2, 3]) (返回3) |
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
⑧ Epoch评估与日志
logger.info():记录日志信息,输出训练过程中的关键状态
参数 | 类型 | 必须 | 说明 | 示例 |
---|---|---|---|---|
msg | str | 是 | 日志消息(支持格式化字符串) | logger.info("Epoch: %d", epoch) |
*args | Any | 否 | 格式化参数(用于% 占位符) |
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
⑨ 完整训练代码
#训练
for epoch in range(config["epoch"]):
epoch += 1
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
for index, batch_data in enumerate(train_data):
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
logger.info("epoch average loss: %f" % np.mean(train_loss))
evaluator.eval(epoch)
Ⅷ、模型保存
os.path.join():用于跨平台路径拼接的核心函数,其核心功能是智能处理不同操作系统的路径分隔符,确保代码的可移植性和健壮性
参数名 | 类型 | 说明 |
---|---|---|
path | 字符串 | 必填参数,起始路径组件。 |
*paths | 可变参数 | 可接受多个路径组件,按顺序拼接。 |
model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
# torch.save(model.state_dict(), model_path)
return model, train_data
4.调用模型预测
# -*- coding: utf-8 -*-
import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
"""
模型训练主程序
"""
def main(config):
#创建保存模型的目录
if not os.path.isdir(config["model_path"]):
os.mkdir(config["model_path"])
#加载训练数据
train_data = load_data(config["train_data_path"], config)
#加载模型
model = TorchModel(config)
# 标识是否使用gpu
cuda_flag = torch.cuda.is_available()
if cuda_flag:
logger.info("gpu可以使用,迁移模型至gpu")
model = model.cuda()
#加载优化器
optimizer = choose_optimizer(config, model)
#加载效果测试类
evaluator = Evaluator(config, model, logger)
#训练
for epoch in range(config["epoch"]):
epoch += 1
model.train()
logger.info("epoch %d begin" % epoch)
train_loss = []
for index, batch_data in enumerate(train_data):
optimizer.zero_grad()
if cuda_flag:
batch_data = [d.cuda() for d in batch_data]
input_id, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况
loss = model(input_id, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
if index % int(len(train_data) / 2) == 0:
logger.info("batch loss %f" % loss)
logger.info("epoch average loss: %f" % np.mean(train_loss))
evaluator.eval(epoch)
model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)
# torch.save(model.state_dict(), model_path)
return model, train_data
if __name__ == "__main__":
model, train_data = main(Config)