Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 2dd0686

Browse filesBrowse files
authored
Add ConstantOfShape to light API (#77)
* update requirements * Add ConstantOfShape to light API
1 parent 4cf9dcc commit 2dd0686
Copy full SHA for 2dd0686

File tree

4 files changed

+28
-2
lines changed
Filter options

4 files changed

+28
-2
lines changed

‎_unittests/ut_light_api/test_light_api.py

Copy file name to clipboardExpand all lines: _unittests/ut_light_api/test_light_api.py
+13-1Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33
from typing import Callable, Optional
44
import numpy as np
5-
from onnx import GraphProto, ModelProto
5+
from onnx import GraphProto, ModelProto, TensorProto
66
from onnx.defs import (
77
get_all_schemas_with_history,
88
onnx_opset_version,
@@ -526,6 +526,18 @@ def test_input_shape(self):
526526
i = str(model.graph.input[0]).replace("\n", "").replace(" ", "")
527527
self.assertNotIn("shape{}", i)
528528

529+
def test_constant_of_shape(self):
530+
onx = (
531+
start()
532+
.vin("X", TensorProto.INT64, shape=[None, None])
533+
.ConstantOfShape()
534+
.vout(shape=[])
535+
.to_onnx()
536+
)
537+
ref = ReferenceEvaluator(onx)
538+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539+
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540+
529541

530542
if __name__ == "__main__":
531543
TestLightApi().test_add()

‎onnx_array_api/light_api/__init__.py

Copy file name to clipboardExpand all lines: onnx_array_api/light_api/__init__.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
def start(
99
opset: Optional[int] = None,
1010
opsets: Optional[Dict[str, int]] = None,
11+
ir_version: Optional[int] = None,
1112
) -> OnnxGraph:
1213
"""
1314
Starts an onnx model.
1415
1516
:param opset: main opset version
1617
:param opsets: others opsets as a dictionary
18+
:param ir_version: specify the ir_version as well
1719
:return: an instance of :class:`onnx_array_api.light_api.OnnxGraph`
1820
1921
A very simple model:
@@ -45,7 +47,7 @@ def start(
4547
)
4648
print(onx)
4749
"""
48-
return OnnxGraph(opset=opset, opsets=opsets)
50+
return OnnxGraph(opset=opset, opsets=opsets, ir_version=ir_version)
4951

5052

5153
def g() -> OnnxGraph:

‎onnx_array_api/light_api/_op_var.py

Copy file name to clipboardExpand all lines: onnx_array_api/light_api/_op_var.py
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import List, Optional, Union
2+
import numpy as np
3+
from ..reference import from_array_extended
24
from ..annotations import AI_ONNX_ML, domain
35

46

@@ -69,6 +71,11 @@ def Cast(self, saturate: int = 1, to: int = 0) -> "Var":
6971
def Celu(self, alpha: float = 1.0) -> "Var":
7072
return self.make_node("Celu", self, alpha=alpha)
7173

74+
def ConstantOfShape(self, value: Optional[np.array] = None) -> "Var":
75+
if value is None:
76+
return self.make_node("ConstantOfShape", self)
77+
return self.make_node("ConstantOfShape", self, value=from_array_extended(value))
78+
7279
def DepthToSpace(self, blocksize: int = 0, mode: str = "DCR") -> "Var":
7380
return self.make_node("DepthToSpace", self, blocksize=blocksize, mode=mode)
7481

‎onnx_array_api/light_api/model.py

Copy file name to clipboardExpand all lines: onnx_array_api/light_api/model.py
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@ class OnnxGraph:
4242
4343
:param opset: main opset version
4444
:param opsets: other opsets as a dictionary
45+
:param ir_version: to specify an ir_version
4546
:param is_function: a :class:`onnx.ModelProto` or a :class:`onnx.FunctionProto`
4647
"""
4748

4849
def __init__(
4950
self,
5051
opset: Optional[int] = None,
5152
opsets: Optional[Dict[str, int]] = None,
53+
ir_version: Optional[int] = None,
5254
proto_type: ProtoType = ProtoType.MODEL,
5355
):
5456
if opsets is not None and "" in opsets:
@@ -65,6 +67,7 @@ def __init__(
6567
self.proto_type = proto_type
6668
self.opsets = opsets
6769
self.opset = opset
70+
self.ir_version = ir_version
6871
self.nodes: List[Union[NodeProto, TensorProto]] = []
6972
self.inputs: List[ValueInfoProto] = []
7073
self.outputs: List[ValueInfoProto] = []
@@ -402,6 +405,8 @@ def to_onnx(self) -> GRAPH_PROTO:
402405
# If no opsets, it a subgraph, not a model.
403406
return graph
404407
model = make_model(graph, opset_imports=opsets)
408+
if self.ir_version:
409+
model.ir_version = ir_version
405410
if not is_windows() or not is_azure():
406411
# check_model fails sometimes on Windows
407412
check_model(model)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.