摘要:主要是解析PT权重转ONNX的过程、代码。文章末尾附有完整运行代码。
目录:
1、导入模型
model = torch.load("yolov8n.pt")
# 由于我用v8做测试,v8的权重中除了模型结构,还有配置文件的参数,这里只需要其模型结构
ckpt = torch.load("yolov8n.pt")
model = ckpt['model']
2、 关闭梯度更新
作用:减少计算所需的资源
for p in model.parameters():
p.requires_grad = False
3、精度转换、fuse等
fuse的作用:将conv和BN层融合,提高推理速度
model.eval()
model.float()
model = model.fuse()
4、参数的设置
device = 'cpu' # 如果cuda能用,设置'cuda:0'
input = torch.zeros((1, 3, 640, 640)).to(device) # 根据模型输入尺寸设置(1, 3, 640, 640)
f = 'name.onnx' # 文件名,保存路径
opset_version = 11 # opset版本
input_names = ['img'] # 输入名
output_names=['out'] # 输出名
# 通过以下规则设置动态的维度
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)}
5、 pt转onnx
torch.onnx.export(
model, # 模型pt权重
input, # 输入张量,模型输入,如[1,3,640,640]
f, # 保存onnx模型的文件
opset_version = opset_version, # Opset版本
input_names=input_names, # 输入张量名称
output_names=output_names, # 输出的张量名称
dynamic_axes=dynamic_axes, # 通过以下规则设置动态的维度
)
也可以结合步骤四直接这么写:
torch.onnx.export(model = torch.load(weight)[model],
im = torch.zeros(1, 3, 640, 640),
f = ./weight.onnx,
export_params = 12,
input_names=['images'],
output_names=['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},
'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)})
6、添加其余信息
比如你想在onnx中保存一些模型之外的信息,比如作者、模型名、classes、Opset版本等信息的话。
metadata = {
'author': 'NO郑',
'version': '11',
'stride': int(max(model.stride)),
'task': 'yolov8',
'batch': 1,
'imgsz': (640, 640),
'labels': ['car', 'person']} # model metadata
model_onnx = onnx.load(f)
for k, v in metadata.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
7、保存onnx
onnx.save(model, f)
8、 完整代码
import torch
import onnx
ckpt = torch.load("yolov8n.pt")
model = ckpt['model']
im = torch.zeros((1, 3, 640, 640)).to('cpu')
for p in model.parameters():
p.requires_grad = False
model.eval()
model.float()
model = model.fuse()
f = 'onnx01.onnx'
torch.onnx.export(model.cpu(), im, f, opset_version=11, input_names=['img'], output_names=['out'], dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
})
metadata = {
'description': '1',
'author': 'NO郑',
'version': '11',
'batch': 1,
'imgsz': (640, 640),
'names': ['class1', 'classes2']} # model metadata
model_onnx = onnx.load(f)
for k, v in metadata.items():
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)
9、ONNX可视化
- 网页搜索netron
- 将ONNX文件打开即可
其中:METADATA是第6步添加的信息