Commit 95a012f

Adds command line to translate the model into code (#49)

* Adds command line to translate the model into code
* doc

1 parent 5e3668d · commit 95a012f

4 files changed: +174, -0 lines
‎CHANGELOGS.rst

1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.1.3
 +++++

+* :pr:`49`: adds command line to export a model into code
 * :pr:`48`: support for subgraph in light API
 * :pr:`47`: extends export onnx to code to support inner API
 * :pr:`46`: adds an export to convert an onnx graph into light API code
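
The new command line exports an ONNX model into the code that rebuilds it. A usage sketch (the model path model.onnx is illustrative, not a file added by this commit): from the shell, python -m onnx_array_api translate -m model.onnx [-a light]; the equivalent direct calls are shown below.

from onnx_array_api._command_lines_parser import main

# Prints code replicating the model with the onnx.helper API (the default) ...
main(["translate", "-m", "model.onnx"])

# ... or with the light API.
main(["translate", "-m", "model.onnx", "-a", "light"])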
New file (unit tests for the command line): 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import os
import tempfile
import unittest
from contextlib import redirect_stdout
from io import StringIO
from onnx import TensorProto
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_opsetid,
    make_tensor_value_info,
)
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api._command_lines_parser import (
    get_main_parser,
    get_parser_translate,
    main,
)


class TestCommandLines1(ExtTestCase):
    def test_main_parser(self):
        st = StringIO()
        with redirect_stdout(st):
            get_main_parser().print_help()
        text = st.getvalue()
        self.assertIn("translate", text)

    def test_parser_translate(self):
        st = StringIO()
        with redirect_stdout(st):
            get_parser_translate().print_help()
        text = st.getvalue()
        self.assertIn("model", text)

    def test_command_translate(self):
        X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
        Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
        Z = make_tensor_value_info("Z", TensorProto.FLOAT, [None, None])
        graph = make_graph(
            [
                make_node("Add", ["X", "Y"], ["res"]),
                make_node("Cos", ["res"], ["Z"]),
            ],
            "g",
            [X, Y],
            [Z],
        )
        onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)])

        with tempfile.TemporaryDirectory() as root:
            model_file = os.path.join(root, "model.onnx")
            with open(model_file, "wb") as f:
                f.write(onnx_model.SerializeToString())

            args = ["translate", "-m", model_file]
            st = StringIO()
            with redirect_stdout(st):
                main(args)

            code = st.getvalue()
            self.assertIn("model = make_model(", code)

            args = ["translate", "-m", model_file, "-a", "light"]
            st = StringIO()
            with redirect_stdout(st):
                main(args)

            code = st.getvalue()
            self.assertIn("start(opset=", code)


if __name__ == "__main__":
    unittest.main(verbosity=2)

‎onnx_array_api/__main__.py

4 additions & 0 deletions

@@ -0,0 +1,4 @@
from ._command_lines_parser import main

if __name__ == "__main__":
    main()
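
__main__.py makes the package runnable as a module, so the command line can be reached with python -m onnx_array_api. A minimal sketch of exercising the entry point, assuming the package is installed in the current environment:

import subprocess
import sys

# Running the package as a module dispatches to main(); with --help,
# argparse prints the main help and exits.
proc = subprocess.run(
    [sys.executable, "-m", "onnx_array_api", "--help"],
    capture_output=True,
    text=True,
)
print(proc.stdout)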
New file (onnx_array_api/_command_lines_parser.py, inferred from the imports above): 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import sys
import onnx
from typing import Any, List, Optional
from argparse import ArgumentParser
from textwrap import dedent


def get_main_parser() -> ArgumentParser:
    parser = ArgumentParser(
        prog="onnx-array-api",
        description="onnx-array-api main command line.",
        epilog="Type 'python -m onnx_array_api <cmd> --help' "
        "to get help for a specific command.",
    )
    parser.add_argument(
        "cmd",
        choices=["translate"],
        help=dedent(
            """
            Selects a command.

            'translate' exports an onnx graph into a piece of code replicating it.
            """
        ),
    )
    return parser


def get_parser_translate() -> ArgumentParser:
    parser = ArgumentParser(
        prog="translate",
        description=dedent(
            """
            Translates an onnx model into a piece of code to replicate it.
            The result is printed on the standard output.
            """
        ),
        epilog="This is mostly used to write unit tests without adding "
        "an onnx file to the repository.",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="onnx model to translate",
    )
    parser.add_argument(
        "-a",
        "--api",
        choices=["onnx", "light"],
        default="onnx",
        help="API to choose, API from onnx package or light API.",
    )
    return parser


def _cmd_translate(argv: List[Any]):
    from .light_api import translate

    parser = get_parser_translate()
    args = parser.parse_args(argv[1:])
    onx = onnx.load(args.model)
    code = translate(onx, api=args.api)
    print(code)


def main(argv: Optional[List[Any]] = None):
    fcts = dict(translate=_cmd_translate)

    if argv is None:
        argv = sys.argv[1:]
    if (len(argv) <= 1 and argv[0] not in fcts) or argv[-1] in ("--help", "-h"):
        if len(argv) < 2:
            parser = get_main_parser()
            parser.parse_args(argv)
        else:
            parsers = dict(translate=get_parser_translate)
            cmd = argv[0]
            if cmd not in parsers:
                raise ValueError(
                    f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
                )
            parser = parsers[cmd]()
            parser.parse_args(argv[1:])
        raise RuntimeError("The programme should have exited before.")

    cmd = argv[0]
    if cmd in fcts:
        fcts[cmd](argv)
    else:
        raise ValueError(
            f"Unknown command {cmd!r}, use --help to get the list of known command."
        )
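
The translate command is a thin wrapper around translate from the light API: it loads the model and prints the generated code. A sketch of the equivalent direct call, assuming a file model.onnx exists on disk:

import onnx
from onnx_array_api.light_api import translate

onx = onnx.load("model.onnx")       # illustrative path
code = translate(onx, api="light")  # api="onnx" (the default) emits onnx.helper-based code
print(code)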

0 commit comments