python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch模型转onnx可视化

pytorch模型转换为onnx可视化(使用netron)

作者:michaelchengjl

netron 是一个非常好用的网络结构可视化工具,但是netron对pytorch模型的支持还不成熟,这篇文章主要介绍了pytorch模型转换为onnx,并使用netron可视化,需要的朋友可以参考下

pytorch模型转换为onnx,并使用netron可视化

netron 是一个非常好用的网络结构可视化工具。

但是netron对pytorch模型的支持还不成熟。自己试的效果是生成的模型图没有连线。

目前支持的框架 根据netron的github

目前netron支持:

ONNX (.onnx, .pb, .pbtxt)
Keras (.h5, .keras)
Core ML (.mlmodel)
Caffe (.caffemodel, .prototxt)
Caffe2 (predict_net.pb, predict_net.pbtxt)
Darknet (.cfg)
MXNet (.model, -symbol.json)
ncnn (.param) 
TensorFlow Lite (.tflite)
PaddlePaddle (.zip, model)
TensorFlow.js
CNTK (.model, .cntk)

并且实验性支持:

TorchScript (.pt, .pth)
PyTorch (.pt, .pth)
Torch (.t7)
Arm NN (.armnn)
BigDL (.bigdl, .model) 
Chainer (.npz, .h5)
Deeplearning4j (.zip)
MediaPipe (.pbtxt)
ML.NET (.zip), MNN (.mnn)
OpenVINO (.xml)
scikit-learn (.pkl)
TensorFlow (.pb, .meta, .pbtxt, .ckpt, .index)

Netron supports ONNX, TensorFlow Lite, Caffe, Keras, Darknet, PaddlePaddle, ncnn, MNN, Core ML, RKNN, MXNet, MindSpore Lite, TNN, Barracuda, Tengine, CNTK, TensorFlow.js, Caffe2 and UFF.

Netron has experimental support for PyTorch, TensorFlow, TorchScript, OpenVINO, Torch, Vitis AI, kmodel, Arm NN, BigDL, Chainer, Deeplearning4j, MediaPipe, ML.NET and scikit-learn.

这里就有一个把 .pth 模型转化为 .onnx 模型。

Pytorch模型转onnx

model = resnet18(pretrained=True)
# print(model)
# old_net_path = "resnet18.pth"
new_net_path = "./resnet18.onnx"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 导入模型
net = model.to(device)
# net.load_state_dict(torch.load(old_net_path, map_location=device))
net.eval()
input = torch.randn(1, 3, 224, 224).to(device)  # BCHW  其中Batch必须为1,因为测试时一般为1,尺寸HW必须和训练时的尺寸一致
torch.onnx.export(net, input, new_net_path, verbose=False)

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)

参数:

model(torch.nn.Module)-要被导出的模型
args(参数的集合)-模型的输入,例如,这种model(*args)方式是对模型的有效调用。任何非Variable参数都将硬编码到导出的模型中;任何Variable参数都将成为导出的模型的输入,并按照他们在args中出现的顺序输入。如果args是一个Variable,这等价于用包含这个Variable的1-ary元组调用它。(注意:现在不支持向模型传递关键字参数。)
f-一个类文件的对象(必须实现文件描述符的返回)或一个包含文件名字符串。一个二进制Protobuf将会写入这个文件中。
export_params(bool,default True)-如果指定,所有参数都会被导出。如果你只想导出一个未训练的模型,就将此参数设置为False。在这种情况下,导出的模型将首先把所有parameters作为参arguments,顺序由model.state_dict().values()指定。
verbose(bool,default False)-如果指定,将会输出被导出的轨迹的调试描述。
training(bool,default False)-导出训练模型下的模型。目前,ONNX只面向推断模型的导出,所以一般不需要将该项设置为True。
input_names(list of strings, default empty list)-按顺序分配名称到图中的输入节点。
output_names(list of strings, default empty list)-按顺序分配名称到图中的输出节点。

文件中保存模型结构和权重参数

import torch
torch_model = torch.load("save.pt") # pytorch模型加载
batch_size = 1  #批处理大小
input_shape = (3,244,244)   #输入数据
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape)		# 生成张量
export_onnx_file = "test.onnx"					# 目的ONNX文件名
torch.onnx.export(torch_model,
                    x,
                    export_onnx_file,
                    opset_version=10,
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名
                    output_names=["output"],	# 输出名
                    dynamic_axes={"input":{0:"batch_size"},		# 批处理变量
                    "output":{0:"batch_size"}})

dynamic_axes字段用于批处理.若不想支持批处理或固定批处理大小,移除dynamic_axes字段即可.

文件中只保留模型权重

import torch
torch_model = selfmodel()  					# 由研究员提供python.py文件
batch_size = 1 								# 批处理大小
input_shape = (3, 244, 244) 				# 输入数据
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape) 	# 生成张量
export_onnx_file = "test.onnx" 				# 目的ONNX文件名
torch.onnx.export(torch_model,
                    x,
                    export_onnx_file,
                    opset_version=10,
                    do_constant_folding=True,	# 是否执行常量折叠优化
                    input_names=["input"],		# 输入名
                    output_names=["output"],	# 输出名
                    dynamic_axes={"input":{0:"batch_size"},	# 批处理变量
                    "output":{0:"batch_size"}})

到此这篇关于pytorch模型转换为onnx可视化(使用netron)的文章就介绍到这了,更多相关pytorch模型转onnx可视化内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文