ONNX Model Construction and Checking
source link: https://leezhao415.github.io/2022/07/26/ONNX%E6%A8%A1%E5%9E%8B%E6%9E%84%E9%80%A0%E4%B8%8E%E4%BB%A3%E7%A0%81%E6%A3%80%E6%9F%A5/
Reference blog: https://zhuanlan.zhihu.com/p/516920606
1 Construct ValueInfoProto objects that describe the tensors (name, element type, shape)
import onnx
from onnx import helper
from onnx import TensorProto
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])
2 Construct the operator nodes (NodeProto)
mul = helper.make_node('Mul', ['a', 'x'], ['c'])
add = helper.make_node('Add', ['c', 'b'], ['output'])
3 Construct the computation graph (GraphProto)
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
4 Wrap the computation graph in a model (ModelProto)
Use helper.make_model to wrap the computation graph (GraphProto) into a model (ModelProto).
model = helper.make_model(graph)
5 Check, print, and save the model
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')
Complete code for constructing the model with the ONNX Python API
import onnx
from onnx import helper
from onnx import TensorProto
# input and output
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])
# Mul
mul = helper.make_node('Mul', ['a', 'x'], ['c'])
# Add
add = helper.make_node('Add', ['c', 'b'], ['output'])
# graph and model
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
model = helper.make_model(graph)
# check, print, and save the model
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')