Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit a8b45f9

Browse filesBrowse files
authored
Replaces long initiliazer by rando values (#98)
* Replaces long initiliazer by rando values * fix display * fix issues
1 parent a868dd3 commit a8b45f9
Copy full SHA for a8b45f9

File tree

5 files changed

+136
-4
lines changed
Filter options

5 files changed

+136
-4
lines changed

‎_doc/api/translate_api.rst

Copy file name to clipboardExpand all lines: _doc/api/translate_api.rst
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ InnerEmitter
3939
.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitter
4040
:members:
4141

42+
InnerEmitterShortInitializer
43+
++++++++++++++++++++++++++++
44+
45+
.. autoclass:: onnx_array_api.translate_api.inner_emitter.InnerEmitterShortInitializer
46+
:members:
47+
4248
LightEmitter
4349
++++++++++++
4450

‎_unittests/ut_ort/test_ort_profile.py

Copy file name to clipboardExpand all lines: _unittests/ut_ort/test_ort_profile.py
-2Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ def myloss(x, y):
5757
prof = ort_profile(optimized, feeds)
5858
events = {
5959
"kernel_time",
60-
"fence_before",
61-
"fence_after",
6260
"SequentialExecutor::Execute",
6361
"model_run",
6462
"model_loading_array",

‎_unittests/ut_translate_api/test_translate_classic.py

Copy file name to clipboardExpand all lines: _unittests/ut_translate_api/test_translate_classic.py
+69Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,75 @@ def test_transpose(self):
178178
self.maxDiff = None
179179
self.assertEqual(expected, code)
180180

181+
def test_transpose_short(self):
182+
onx = (
183+
start(opset=19)
184+
.vin("X")
185+
.reshape((-1, 1))
186+
.Transpose(perm=[1, 0])
187+
.rename("Y")
188+
.vout()
189+
.to_onnx()
190+
)
191+
self.assertIsInstance(onx, ModelProto)
192+
self.assertIn("Transpose", str(onx))
193+
ref = ReferenceEvaluator(onx)
194+
a = np.arange(10).astype(np.float32)
195+
got = ref.run(None, {"X": a})[0]
196+
self.assertEqualArray(a.reshape((-1, 1)).T, got)
197+
198+
code = translate(onx, api="onnx-short")
199+
expected = dedent(
200+
"""
201+
opset_imports = [
202+
make_opsetid('', 19),
203+
]
204+
inputs = []
205+
outputs = []
206+
nodes = []
207+
initializers = []
208+
sparse_initializers = []
209+
functions = []
210+
initializers.append(
211+
from_array(
212+
np.array([-1, 1], dtype=np.int64),
213+
name='r'
214+
)
215+
)
216+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
217+
nodes.append(
218+
make_node_extended(
219+
'Reshape',
220+
['X', 'r'],
221+
['r0_0']
222+
)
223+
)
224+
nodes.append(
225+
make_node_extended(
226+
'Transpose',
227+
['r0_0'],
228+
['Y'],
229+
perm=[1, 0]
230+
)
231+
)
232+
outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[]))
233+
graph = make_graph(
234+
nodes,
235+
'light_api',
236+
inputs,
237+
outputs,
238+
initializers,
239+
sparse_initializer=sparse_initializers,
240+
)
241+
model = make_model(
242+
graph,
243+
functions=functions,
244+
opset_imports=opset_imports
245+
)"""
246+
).strip("\n")
247+
self.maxDiff = None
248+
self.assertEqual(expected, code)
249+
181250
def test_topk_reverse(self):
182251
onx = (
183252
start(opset=19)

‎onnx_array_api/translate_api/__init__.py

Copy file name to clipboardExpand all lines: onnx_array_api/translate_api/__init__.py
+6-2Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from onnx import ModelProto
22
from .translate import Translater
3-
from .inner_emitter import InnerEmitter
3+
from .inner_emitter import InnerEmitter, InnerEmitterShortInitializer
44
from .builder_emitter import BuilderEmitter
55

66

@@ -16,7 +16,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1616
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1717
another value is `"onnx"` which is the inner API implemented
1818
in onnx package, `"builder"` follows the syntax for the
19-
class :class:`onnx_array_api.graph_api.GraphBuilder`
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`,
20+
`"onnx-short"` replaces long initializer with random values
2021
:return: code
2122
2223
.. runpython::
@@ -84,6 +85,9 @@ class :class:`onnx_array_api.graph_api.GraphBuilder`
8485
if api == "onnx":
8586
tr = Translater(proto, emitter=InnerEmitter())
8687
return tr.export(as_str=True)
88+
if api == "onnx-short":
89+
tr = Translater(proto, emitter=InnerEmitterShortInitializer())
90+
return tr.export(as_str=True)
8791
if api == "builder":
8892
tr = Translater(proto, emitter=BuilderEmitter())
8993
return tr.export(as_str=True)

‎onnx_array_api/translate_api/inner_emitter.py

Copy file name to clipboardExpand all lines: onnx_array_api/translate_api/inner_emitter.py
+55Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
106106
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
107107
else:
108108
sdtype = f"np.{sdtype}"
109+
109110
return [
110111
"initializers.append(",
111112
f" {fra}(",
@@ -209,3 +210,57 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]:
209210
")",
210211
]
211212
return lines
213+
214+
215+
class InnerEmitterShortInitializer(InnerEmitter):
216+
"""
217+
Converts event into proper code.
218+
Initializer are replaced by random values if too big.
219+
"""
220+
221+
def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
222+
name = kwargs["name"]
223+
value = kwargs["value"]
224+
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
225+
fra = "from_array"
226+
sdtype = repl.get(str(value.dtype), str(value.dtype))
227+
if sdtype.startswith("("):
228+
from onnx.reference.custom_element_types import float8e4m3fn
229+
230+
if sdtype == str(float8e4m3fn):
231+
sdtype = "float8e4m3fn"
232+
fra = "from_array_extended"
233+
else:
234+
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
235+
else:
236+
sdtype = f"np.{sdtype}"
237+
if value.size <= 16:
238+
return [
239+
"initializers.append(",
240+
f" {fra}(",
241+
f" np.array({value.tolist()}, dtype={sdtype}),",
242+
f" name={name!r}",
243+
" )",
244+
")",
245+
]
246+
if "int" in sdtype:
247+
return [
248+
f"value = np.random.randint(0, 10, size={value.shape})"
249+
f".astype({sdtype})",
250+
"initializers.append(",
251+
f" {fra}(",
252+
f" np.array(value, dtype={sdtype}),",
253+
f" name={name!r}",
254+
" )",
255+
")",
256+
]
257+
return [
258+
f"value = np.random.randn({', '.join(map(str,value.shape))})"
259+
f".astype({sdtype})",
260+
"initializers.append(",
261+
f" {fra}(",
262+
f" np.array(value, dtype={sdtype}),",
263+
f" name={name!r}",
264+
" )",
265+
")",
266+
]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.