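# 模型转换 (Model conversion)
# PyTorch 转 ONNX，并用 onnxruntime 验证模型转换是否正确。
# (Convert a PyTorch model to ONNX, then use onnxruntime to verify that the conversion is correct.)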
import torch.utils.data
from torch.autograd import Variable
from squeezenet import squeezenet1_2
import onnxruntime
import numpy as np
from onnxruntime.datasets import get_example
import cv2
import onnx
# Load the trained classifier and build a sample input tensor for ONNX export
# and for comparing PyTorch vs. onnxruntime outputs.

# Checkpoint file; it is indexed with ['state_dict'] below, so it is a dict.
model_file = '/home/bbt/qinghua/detetion/pytorch-mobilenet/model/model_best_squeezenet1_2_age0.926.pth.tar'
num_class = 3

# Create the model. It is wrapped in DataParallel *before* loading so the
# state_dict keys line up with the checkpoint -- presumably it was saved from a
# DataParallel-wrapped model ('module.'-prefixed keys); TODO confirm.
model = squeezenet1_2(pretrained=False, num_classes=num_class)
model = torch.nn.DataParallel(model)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; the model is moved to and run on the CPU below anyway.
checkpoint = torch.load(model_file, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.to('cpu')
model.eval()  # eval mode so export/inference use inference-time layer behavior

# Build the sample input.
image_path = '/home/bbt/age00.jpg'
# NOTE: `input` shadows the builtin; the name is kept because later code may use it.
input = cv2.imread(image_path)
if input is None:
    # cv2.imread signals a missing/unreadable file by returning None rather than
    # raising; fail here with a clear message instead of inside cv2.resize.
    raise FileNotFoundError('could not read image: ' + image_path)
input = cv2.resize(input, (224, 224))
# HWC -> CHW, float32. Pixels remain in OpenCV's BGR order with 0-255 values
# and no normalization. NOTE(review): confirm this matches training preprocessing.
input = np.transpose(input, (2, 0, 1)).astype(np.float32)
# torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor
# (requires_grad=False) is equivalent.
now_image1 = torch.from_numpy(input)
dummy_input = now_image1.unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)
input_names=['i