Commit 3de3c5d

Supports for local functions in translator (#96)
* fix suffix
* one fix
* fix
* fix ut
* fix ir_version
* doc
1 parent 664e084 commit 3de3c5d

File tree

6 files changed (+257, -32 lines)

CHANGELOGS.rst

1 addition, 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.3.1
 +++++
 
+* :pr:`96`: supports local functions in translator
 * :pr:`95`: improves translation to GraphBuilder
 
 0.3.0
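
For context on these two entries, here is a minimal usage sketch of the translator they refer to. It is not part of this commit, and the import path (onnx_array_api.translate_api.translate) is an assumption inferred from the package layout shown below; the tests only call translate(onx, api="builder") directly.

    # Hypothetical usage sketch (not part of this commit): translate an ONNX
    # model into Python code that rebuilds it with GraphBuilder.
    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_array_api.translate_api import translate  # assumed import path

    # A one-node model: Y = Exp(X).
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Exp", ["X"], ["Y"])],
            "example",
            [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])],
        ),
        opset_imports=[oh.make_opsetid("", 19)],
        ir_version=10,
    )

    # Returns the generated source as a string, as exercised by the tests below.
    code = translate(model, api="builder")
    print(code)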

_unittests/ut_translate_api/test_translate_builder.py

125 additions, 19 deletions
@@ -1,6 +1,7 @@
 import unittest
 from textwrap import dedent
 import numpy as np
+import onnx.helper as oh
 from onnx import ModelProto, TensorProto
 from onnx.checker import check_model
 from onnx.defs import onnx_opset_version
@@ -29,37 +30,43 @@ def test_exp(self):
 self.assertEqualArray(np.exp(a), got)
 
 code = translate(onx, api="builder")
-expected = dedent(
-    """
+expected = (
+    dedent(
+        """
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]",
 ):
-    Y = op.Exp(X)
+    Y = op.Exp(X, outputs=['Y'])
     op.Identity(Y, outputs=["Y"])
     return Y
 
 g = GraphBuilder({'': 19}, ir_version=10)
 g.make_tensor_input("X", TensorProto.FLOAT, ())
 light_api(g.op, "X")
-g.make_tensor_output("Y", TensorProto.FLOAT, ())
+g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
 model = g.to_onnx()
 """
-).strip("\n")
+    )
+    .strip("\n")
+    .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+)
 self.assertEqual(expected, code.strip("\n"))
 
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]", # noqa: F722
 ):
-    Y = op.Exp(X)
+    Y = op.Exp(X, outputs=["Y"])
     op.Identity(Y, outputs=["Y"])
     return Y
 
 g2 = GraphBuilder({"": 19})
 g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
 light_api(g2.op, "X")
-g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
+g2.make_tensor_output(
+    "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
+)
 onx2 = g2.to_onnx()
 
 ref = ReferenceEvaluator(onx2)
@@ -78,25 +85,29 @@ def test_zdoc(self):
     .to_onnx()
 )
 code = translate(onx, api="builder")
-expected = dedent(
-    """
+expected = (
+    dedent(
+        """
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]",
 ):
     r = np.array([-1, 1], dtype=np.int64)
-    r0_0 = op.Reshape(X, r)
-    Y = op.Transpose(r0_0, perm=[1, 0])
+    r0_0 = op.Reshape(X, r, outputs=['r0_0'])
+    Y = op.Transpose(r0_0, perm=[1, 0], outputs=['Y'])
     op.Identity(Y, outputs=["Y"])
     return Y
 
 g = GraphBuilder({'': 19}, ir_version=10)
 g.make_tensor_input("X", TensorProto.FLOAT, ())
 light_api(g.op, "X")
-g.make_tensor_output("Y", TensorProto.FLOAT, ())
+g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
 model = g.to_onnx()
 """
-).strip("\n")
+    )
+    .strip("\n")
+    .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+)
 self.maxDiff = None
 self.assertEqual(expected, code.strip("\n"))
 
@@ -130,13 +141,14 @@ def test_exp_f(self):
 tr = Translater(onx, emitter=BuilderEmitter("mm"))
 code = tr.export(as_str=True)
 
-expected = dedent(
-    """
+expected = (
+    dedent(
+        """
 def light_api(
     op: "GraphBuilder",
     X: "FLOAT[]",
 ):
-    Y = op.Exp(X)
+    Y = op.Exp(X, outputs=['Y'])
     op.Identity(Y, outputs=["Y"])
     return Y
 
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
 g = GraphBuilder({'': 19}, ir_version=10)
 g.make_tensor_input("X", TensorProto.FLOAT, ())
 light_api(g.op, "X")
-g.make_tensor_output("Y", TensorProto.FLOAT, ())
+g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
 model = g.to_onnx()
 return model
 
 
 model = mm()
 """
-).strip("\n")
+    )
+    .strip("\n")
+    .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+)
 self.assertEqual(expected, code.strip("\n"))
 
 def light_api(
@@ -166,14 +181,105 @@ def light_api(
 g2 = GraphBuilder({"": 19})
 g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
 light_api(g2.op, "X")
-g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
+g2.make_tensor_output(
+    "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
+)
 onx2 = g2.to_onnx()
 
 ref = ReferenceEvaluator(onx2)
 a = np.arange(10).astype(np.float32)
 got = ref.run(None, {"X": a})[0]
 self.assertEqualArray(np.exp(a), got)
 
+def test_local_function(self):
+    new_domain = "custom"
+
+    linear_regression = oh.make_function(
+        new_domain,
+        "LinearRegression",
+        ["x", "a", "b"],
+        ["y"],
+        [
+            oh.make_node("MatMul", ["x", "a"], ["xa"]),
+            oh.make_node("Add", ["xa", "b"], ["y"]),
+        ],
+        [oh.make_opsetid("", 14)],
+        [],
+    )
+
+    graph = oh.make_graph(
+        [
+            oh.make_node(
+                "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
+            ),
+            oh.make_node("Abs", ["Y1"], ["Y"]),
+        ],
+        "example",
+        [
+            oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
+            oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
+            oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
+        ],
+        [oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
+    )
+
+    onnx_model = oh.make_model(
+        graph,
+        opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
+        functions=[linear_regression],
+        ir_version=10,
+    )
+    tr = Translater(onnx_model, emitter=BuilderEmitter("mm"))
+    code = tr.export(as_str=True)
+
+    expected = (
+        dedent(
+            """
+def example(
+    op: "GraphBuilder",
+    X: "FLOAT[, ]",
+    A: "FLOAT[, ]",
+    B: "FLOAT[, ]",
+):
+    Y1 = op.LinearRegression(X, A, B, domain='custom', outputs=['Y1'])
+    Y = op.Abs(Y1, outputs=['Y'])
+    op.Identity(Y, outputs=["Y"])
+    return Y
+
+
+def make_custom_LinearRegression(g: "GraphBuilder"):
+    gr = GraphBuilder({'': 14}, as_function=True)
+    x = gr.make_tensor_input('x')
+    a = gr.make_tensor_input('a')
+    b = gr.make_tensor_input('b')
+    op = gr.op
+    xa = op.MatMul(x, a, outputs=['xa'])
+    y = op.Add(xa, b, outputs=['y'])
+    gr.make_tensor_output(y)
+    g.add_function(builder=gr)
+    return gr
+
+
+def mm() -> "ModelProto":
+    g = GraphBuilder({'': 14, 'custom': 1}, ir_version=10)
+    g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
+    g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
+    g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
+    example(g.op, "X", "A", "B")
+    g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
+    make_custom_LinearRegression(g)
+    model = g.to_onnx()
+    return model
+
+
+model = mm()
+"""
+        )
+        .strip("\n")
+        .replace("__SUFFIX__", ", is_dimension=False, indexed=False")
+    )
+    self.assertEqual(expected, code.strip("\n"))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
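
The string returned by tr.export(as_str=True) is plain Python ending in model = mm(), so it can be executed to rebuild the ModelProto. The following round-trip sketch is an illustration, not part of this commit; the exec namespace and the GraphBuilder import path are assumptions.

    # Sketch only: execute the generated source from test_local_function and
    # run the rebuilt model. `code` is the string produced by tr.export(as_str=True).
    import numpy as np
    from onnx import TensorProto
    from onnx.reference import ReferenceEvaluator
    from onnx_array_api.graph_api import GraphBuilder  # assumed import path

    # np is included as a precaution: some generated snippets reference it.
    ns = {"GraphBuilder": GraphBuilder, "TensorProto": TensorProto, "np": np}
    exec(code, ns)        # defines example(), make_custom_LinearRegression(), mm()
    model = ns["model"]   # the generated source ends with `model = mm()`

    ref = ReferenceEvaluator(model)
    x = np.random.rand(3, 2).astype(np.float32)
    a = np.random.rand(2, 2).astype(np.float32)
    b = np.random.rand(3, 2).astype(np.float32)
    print(ref.run(None, {"X": x, "A": a, "B": b})[0])  # |x @ a + b|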

onnx_array_api/graph_api/graph_builder.py

13 additions, 0 deletions
@@ -194,6 +194,7 @@ def __init__(
     self._known_shapes = {}
     self._known_types = {}
     self.constants_ = {}
+    self.functions_ = {}
 elif isinstance(target_opset_or_existing_proto, ModelProto):
     assert (
         not input_names
@@ -223,6 +224,8 @@ def __init__(
         self.constants_[node.output[0]] = node
         self.set_shape(node.output[0], self._get_tensor_shape(node))
         self.set_type(node.output[0], self._get_tensor_type(node))
+    for f in proto.functions:
+        self.add_function(f)
 else:
     raise NotImplementedError(
         f"{type(target_opset_or_existing_proto)} is not supported."
@@ -231,6 +234,14 @@ def __init__(
     self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None
     self._cache_array = []
 
+def add_local_function(self, domain: str, name: str, gr: "GraphBuilder"):
+    "Adds a local function."
+    assert (
+        domain,
+        name,
+    ) not in self.functions_, f"Function {(domain, name)} was already added."
+    self.functions_[domain, name] = gr
+
 def _get_tensor_shape(
     self, proto: Union[NodeProto, TensorProto]
 ) -> Tuple[int, ...]:
@@ -417,6 +428,8 @@ def make_tensor_output(
     name: Union[str, List[str]],
     elem_type: Optional[int] = None,
     shape: Optional[Tuple[int, ...]] = None,
+    is_dimension: bool = False,
+    indexed: bool = False,
 ) -> Union[str, List[str]]:
     if isinstance(name, list):
         res = []
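
Taken together, these hunks let a GraphBuilder carry local functions (the functions_ dictionary, add_local_function, and the loading of proto.functions) and accept the extra output flags that the translator now emits. Below is a minimal sketch of the resulting call pattern, mirroring the generated code checked by the tests above; the onnx_array_api.graph_api import path is an assumption.

    # Sketch of the builder call pattern exercised by the updated tests.
    from onnx import TensorProto
    from onnx_array_api.graph_api import GraphBuilder  # assumed import path


    def light_api(op, x):
        # same shape as the helper the translator generates
        y = op.Exp(x, outputs=["Y"])
        op.Identity(y, outputs=["Y"])
        return y


    g = GraphBuilder({"": 19})
    g.make_tensor_input("X", TensorProto.FLOAT, ("A",))
    light_api(g.op, "X")
    # make_tensor_output now accepts the flags emitted by the translator
    g.make_tensor_output(
        "Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
    )
    onx = g.to_onnx()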

onnx_array_api/translate_api/base_emitter.py

28 additions, 0 deletions
@@ -25,6 +25,10 @@ class EventType(IntEnum):
 END_SIGNATURE = 16
 BEGIN_RETURN = 17
 END_RETURN = 18
+BEGIN_FUNCTION_SIGNATURE = 19
+END_FUNCTION_SIGNATURE = 20
+BEGIN_FUNCTION_RETURN = 21
+END_FUNCTION_RETURN = 22
 
 @classmethod
 def to_str(cls, self) -> str:
@@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
 if event == EventType.BEGIN_FUNCTION:
     return self._emit_begin_function(**kwargs)
 
+if event == EventType.BEGIN_FUNCTION_SIGNATURE:
+    return self._emit_begin_function_signature(**kwargs)
+
+if event == EventType.END_FUNCTION_SIGNATURE:
+    return self._emit_end_function_signature(**kwargs)
+
 if event == EventType.END_FUNCTION:
     return self._emit_end_function(**kwargs)
 
@@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
 if event == EventType.END_RETURN:
     return self._emit_end_return(**kwargs)
 
+if event == EventType.BEGIN_FUNCTION_RETURN:
+    return self._emit_begin_function_return(**kwargs)
+
+if event == EventType.END_FUNCTION_RETURN:
+    return self._emit_end_function_return(**kwargs)
+
 raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
 
 def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
         f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
     )
 
+def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+    return []
+
+def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
+    return []
+
 def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
     raise NotImplementedError(
         f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
@@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
 
 def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
     return []
+
+def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+    return []
+
+def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
+    return []
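
The four new events default to emitting nothing, so existing emitters keep working unchanged; a subclass can hook them to decorate the generated local functions. The sketch below assumes the concrete emitter class is importable as onnx_array_api.translate_api.builder_emitter.BuilderEmitter; only the class name appears in the tests above, so the module path is a guess.

    from typing import Any, Dict, List

    # Assumed module path; the tests above only show the class name BuilderEmitter.
    from onnx_array_api.translate_api.builder_emitter import BuilderEmitter


    class VerboseBuilderEmitter(BuilderEmitter):
        # The base class returns [] for all four new hooks, so overriding is optional.
        def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
            return ["# local function signature starts here"]

        def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
            return ["# end of local function"]


    # Drop-in replacement for BuilderEmitter in the tests:
    #     tr = Translater(onnx_model, emitter=VerboseBuilderEmitter("mm"))
    #     code = tr.export(as_str=True)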

0 commit comments
