python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch模型部署onnx

pytorch模型部署到onnx的详细过程

作者:Liang-ml

这篇文章主要介绍了如何简单的将pytorch模型部署到onnx,本文结合示例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

1. 创建一个pytorch模型

这里我用的U2Net,直接加载好训练出的权重

model = U2Net(class_nums=4)
model.load_state_dict(torch.load(checkpoint_path))

2. 将pytorch模型转成onnx格式

x = torcg.randn(1,3,512,512)
with torch.no_grad():
torch.onnx.export(
model,
x,
"onnx_file_name.onnx",
opset_version=11,
input_names=['input'],
output_names=['output']
)

torch.onnx.export()函数用于将pytorch模型转成onnx格式的函数,其中的参数如下

转换成功会在指定的路径f下生成onnx模型文件

可用onnx中的函数验证模型文件是否正确

import onnx
onnx_model = onnx.load("onnx_file_name.onnx")
try:
onnx.checker.check_model(onnx_model)
except Exception as ex:
print(f'ERROR: {ex}')
else:
print(f'model correct')

可以使用netron查看转换的onnx模型的详细内容 链接 https://netron.app/

点击某个算子节点,可以查看算子的具体信息,每个算子记录了算子的属性、图结构和权重三类信息

3. 推理引擎ONNXRuntime

使用onnxruntime可以运行输出的onnx模型,代码如下

input_img 是一个与onnx模型输入维度一致的归一化的np.array
model = onnxruntime.InferenceSession(onnx_file)
inputs = {'input':input_img}
output = model.run(['output'],inputs)[0]

onnxruntime.InferenceSession用于获取一个ONNX Runtime的推理器,其参数为用于推理的ONNX模型文件

run方法用于模型推理,第一个参数为输出张量名的列表,第二个参数为输入值的字典,key为导出模型时设置的张量名,value为输入张量

到此这篇关于简单的将pytorch模型部署到onnx的文章就介绍到这了,更多相关pytorch模型部署到onnx内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文