Bootstrap

2、TensorRT学习笔记之PT转ONNX、可视化ONNX

        摘要:主要是解析PT权重转ONNX的过程、代码。文章末尾附有完整运行代码。

目录:

        1、导入模型

        2、 关闭梯度更新

        3、精度转换、fuse等

        4、参数的设置

        5、 pt转onnx

        6、添加其余信息

        7、保存onnx

        8、 完整代码

        9、ONNX可视化


1、导入模型

model  = torch.load("yolov8n.pt")  

# 由于我用v8做测试,v8的权重中除了模型结构,还有配置文件的参数,这里只需要其模型结构
ckpt = torch.load("yolov8n.pt")    
model = ckpt['model']

2、 关闭梯度更新

        作用:减少计算所需的资源

for p in model.parameters():
    p.requires_grad = False

3、精度转换、fuse等

        fuse的作用:将conv和BN层融合,提高推理速度

model.eval()
model.float()
model = model.fuse()

4、参数的设置

device = 'cpu'        # 如果cuda能用,设置'cuda:0'
input = torch.zeros((1, 3, 640, 640)).to(device)      # 根据模型输入尺寸设置(1, 3, 640, 640)
f = 'name.onnx'                # 文件名,保存路径
opset_version = 11             # opset版本
input_names = ['img']          # 输入名
output_names=['out']           # 输出名
# 通过以下规则设置动态的维度
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                         'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)}        

5、 pt转onnx

torch.onnx.export(
    model,                            # 模型pt权重
    input,                             # 输入张量,模型输入,如[1,3,640,640]
    f,                                 # 保存onnx模型的文件
    opset_version = opset_version,     # Opset版本
    input_names=input_names,           # 输入张量名称
    output_names=output_names,         # 输出的张量名称
    dynamic_axes=dynamic_axes,        # 通过以下规则设置动态的维度
)

也可以结合步骤四直接这么写:

torch.onnx.export(model = torch.load(weight)[model],
                  im = torch.zeros(1, 3, 640, 640), 
                  f = ./weight.onnx, 
                  export_params = 12,
                  input_names=['images'],
                  output_names=['output'],
                  dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, 
                                'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)}) 

6、添加其余信息

        比如你想在onnx中保存一些模型之外的信息,比如作者、模型名、classes、Opset版本等信息的话。

metadata = {
            'author': 'NO郑',
            'version': '11',
            'stride': int(max(model.stride)),
            'task': 'yolov8',
            'batch': 1,
            'imgsz': (640, 640),
            'labels': ['car', 'person']}  # model metadata
model_onnx = onnx.load(f)
for k, v in metadata.items():
    meta = model_onnx.metadata_props.add()
    meta.key, meta.value = k, str(v)

7、保存onnx

onnx.save(model, f)

8、 完整代码

import torch
import onnx

ckpt = torch.load("yolov8n.pt")
model = ckpt['model']
im = torch.zeros((1, 3, 640, 640)).to('cpu')
for p in model.parameters():
    p.requires_grad = False

model.eval()
model.float()
model = model.fuse()
f = 'onnx01.onnx'

torch.onnx.export(model.cpu(), im, f, opset_version=11, input_names=['img'], output_names=['out'], dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                })

metadata = {
            'description': '1',
            'author': 'NO郑',
            'version': '11',
            'batch': 1,
            'imgsz': (640, 640),
            'names': ['class1', 'classes2']}  # model metadata
model_onnx = onnx.load(f)
for k, v in metadata.items():
    meta = model_onnx.metadata_props.add()
    meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)

9、ONNX可视化

  •         网页搜索netron
  •         将ONNX文件打开即可

        其中:METADATA是第6步添加的信息

上一篇:1、TensorRT学习笔记之安装TensorRT

下一篇:3、TensorRT学习笔记之ONNX转engine

;