pytorch模型部署 pth转onnx的方法
作者:aoyou19 发布时间:2022-07-05 03:49:04
Pytorch转ONNX的意义
一般来说转ONNX只是一个手段,在之后得到ONNX模型后还需要再将它做转换,比如转换到TensorRT上完成部署,或者有的人多加一步,从ONNX先转换到caffe,再从caffe到tensorRT。Pytorch自带的torch.onnx.export转换得到的ONNX,ONNXRuntime需要的ONNX,TensorRT需要的ONNX都是不同的。
将pytorch训练保存的pth文件转为onnx文件,为后续模型部署做准备。
一、分类模型
import torch
import os
import timm
import argparse
from utils_net import Resnet
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='classify_model.pth')
parser.add_argument("--save_onnx_path", default='classify_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=6)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_chal, num_cls):
if not onnx_path.endswith('.onnx'):
print('Warning! The onnx model name is not correct,\
please give a name that ends with \'.onnx\'!')
return 0
model = Resnet(num_classes=num_cls)
model.load_state_dict(torch.load(pth_path))
model.eval()
print(f'{pth_path} model loaded')
input_names = ['input']
output_names = ['output']
im = torch.rand(1, in_chal, in_hig, in_wid)
torch.onnx.export(model, im, onnx_path,
verbose=False,
input_names=input_names,
output_names=output_names)
print("Exporting .pth model to onnx model has been successful!")
print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
pth_to_onnx(pth_path=args.pth_path,
onnx_path=args.save_onnx_path,
in_hig=args.input_height,
in_wid=args.input_width,
in_chal=args.input_channel,
num_cls=args.num_classes)
运行结果:
classify_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as classify_model.onnxProcess finished with exit code 0
二、分割模型
import torch
import os
import argparse
from utils_net import seg_net
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='segment_model.pth')
parser.add_argument("--save_onnx_path", default='segment_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=4)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_channel, num_cls):
if not onnx_path.endswith('.onnx'):
print('Warning! The onnx model name is not correct,\
please give a name that ends with \'.onnx\'!')
return 0
model = seg_net(in_channel=in_channel, num_cls=num_cls)
model.load_state_dict(torch.load(pth_path))
model.eval()
print(f'{pth_path} model loaded')
input_names = ['input']
output_names = ['output']
im = torch.rand(1, in_channel, in_hig, in_wid)
torch.onnx.export(model, im, onnx_path,
verbose=False,
input_names=input_names,
output_names=output_names,
opset_version=11)
print("Exporting .pth model to onnx model has been successful!")
print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
pth_to_onnx(pth_path=args.pth_path,
onnx_path=args.save_onnx_path,
in_hig=args.input_height,
in_wid=args.input_width,
in_channel=args.input_channel,
num_cls=args.num_classes)
运行结果:
segment_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as segment_model.onnxProcess finished with exit code 0
三、目标检测模型
在这里插入代码片
import torch
import onnx
import argparse
from utils_net import YoloBody
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='yolo.pth')
parser.add_argument("--save_onnx_path", default='yolo.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--num_classes", default=2)
parser.add_argument("--anchors_mask", default=[[6, 7, 8], [3, 4, 5], [0, 1, 2]])
args = parser.parse_args()
def pth_to_onnx(pth_path: str, save_onnx_path: str, num_cls: int,
in_hig: int, in_wid: int, anchor_mask: list,
opset_version: int = 12, simplify: bool = False):
"""
:param pth_path: pth文件文件
:param save_onnx_path: 准备保存的onnx路径
:param num_cls: 检测目标类别数
:param in_hig: 网络输入高度
:param in_wid: 网络输入宽度
:param anchor_mask: anchor宽高索引
:param opset_version: onnx算子集版本
:param simplify: 是否对模型进行简化
:return:保存onnx到指定路径
"""
# Build model, load weights
net = YoloBody(anchors_mask=anchor_mask,
num_classes=num_cls)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# net.load_state_dict(torch.load(pth_path, map_location=device))
net.load_state_dict(torch.load(pth_path))
# print(next(net.parameters()).device)
net = net.eval()
print(f'{pth_path} model loaded')
im = torch.zeros(1, 3, in_hig, in_wid).to('cpu')
input_layer_names = ['images']
output_layer_names = ['output']
# Export the model
print(f'Starting export with onnx {onnx.__version__}.')
torch.onnx.export(net,
im,
f=save_onnx_path,
verbose=False,
opset_version=opset_version,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=input_layer_names,
output_names=output_layer_names,
dynamic_axes=None)
# Checks
model_onnx = onnx.load(save_onnx_path) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Simplify onnx
if simplify:
import onnxsim
print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=False,
input_shapes=None)
assert check, 'assert check failed'
onnx.save(model_onnx, save_onnx_path)
print('Onnx model save as {}'.format(save_onnx_path))
if __name__ == '__main__':
pth_to_onnx(pth_path=args.pth_path,
save_onnx_path=args.save_onnx_path,
num_cls=args.num_classes,
in_hig=args.input_height,
in_wid=args.input_width,
anchor_mask=args.anchors_mask)
运行结果:
yolo.pth model loaded
Starting export with onnx 1.11.0.
Onnx model save as yolo.onnxProcess finished with exit code 0
参考链接:
1.yolo
2.模型部署翻车记:pytorch转onnx踩坑实录
来源:https://blog.csdn.net/aoyou19/article/details/129407797


猜你喜欢
- 因为在做一个项目需要筛选掉一部分产品列表中的产品,使其在列表显示时排在最后,但是所有产品都要按照更新时间排序。研究了一下系统的数据库结构后,
- 如何远程注册DLL?试试下面的代码:<% Response.Buffer = True %&g
- 方法1: 用SET PASSWORD命令mysql -u rootmysql> SET PASSWORD FOR 'root&
- Python中有以下几个基本的数据类型:整数 int字符串 str浮点数 float集合 set列表 list元组 tuple字典 dict
- 本文实例讲述了Python实现操纵控制windows注册表的方法。分享给大家供大家参考,具体如下:使用_winreg模块的话基本概念:KEY
- 随着移动端的用户越来越多,传统的web系统架构无法兼容很多移动终端的正常使用。在工作中也会发现,现在很多的客户都有在手机、平板等移动终端上使
- 如下所示: m_start =date +' 09:00' m_end =date +' 13:00'rsv
- wechat_sender 是基于 wxpy 和 tornado 实现的一个可以将你的网站、爬虫、脚本等其他应用中各种消息 (日志、报警、运
- 一、django的模板:在settings.py的文件中可以看到并设置这个模板。1.直接映射:通过建立的文件夹(templates)和文件(
- 本文实例为大家分享了python使用itchat实现手机控制电脑的具体代码,供大家参考,具体内容如下1.准备材料首先电脑上需要安装了pyth
- mysql中写判断语句的方法:方法一.CASE函数case函数语法:CASE conditionWHEN value1 THEN retur
- 插值主要用于物理学数学中,逼近某一确定值的方法(1)插值是通过已知的离散数据求未知数据的方法。(2)与拟合不同,插值要求曲线通过所有的已知数
- 随着新技术的不断发展,JavaScript已经不再仅仅只是一个网络语言。现在,我们能够看到很多使用JavaScript来构建基于本地浏览器的
- 一. meta方法打包好的入口index.html头部加入<META HTTP-EQUIV="pragma" CO
- python数组进行降维在深度学习训练过程中,我们有时候想要输出图片看看图片长什么样,但是训练时的图片格式一般都会多出一个批次的维度,如[1
- Python中单类继承Python是一门面向对象的编程语言,支持类继承。新的类称为子类(Subclass),被继承的类称为父类、基类或者超类
- 本文实例讲述了Python使用Matplotlib实现雨点图动画效果的方法。分享给大家供大家参考,具体如下:关键点win10安装ffmpeg
- 系统环境centos7python2.7先在操作系统安装expect[root@V71 python]# vi 3s.py#!/usr/bi
- 视频加密流程图:后端获取保利威的视频播放授权token,提供接口api给前端参考文档:http://dev.polyv.net/2019/v
- JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,易于人阅读和编写。JSON 函数使用 JSON