Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit a906010

Browse filesBrowse files
authored
Documentation (#78)
* update requirements * Add ConstantOfShape to light API * add slice * changelogs * k
1 parent 2dd0686 commit a906010
Copy full SHA for a906010

File tree

4 files changed

+38
-2
lines changed
Filter options

4 files changed

+38
-2
lines changed

‎CHANGELOGS.rst

Copy file name to clipboardExpand all lines: CHANGELOGS.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`77`: supports ConcatOfShape and Slice with the light API
78
* :pr:`76`: add a mode to compare models without execution
89
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
910
* :pr:`71`: adds tools to compare two onnx graphs

‎_unittests/ut_light_api/test_light_api.py

Copy file name to clipboardExpand all lines: _unittests/ut_light_api/test_light_api.py
+29-1Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,35 @@ def test_constant_of_shape(self):
538538
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539539
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540540

541+
def test_constant_of_shape_value(self):
542+
onx = (
543+
start()
544+
.vin("X", TensorProto.INT64, shape=[None, None])
545+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
546+
.vout(shape=[])
547+
.to_onnx()
548+
)
549+
ref = ReferenceEvaluator(onx)
550+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
551+
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)
552+
553+
def test_slice(self):
554+
onx = (
555+
start(opset=18, ir_version=9)
556+
.cst(np.array([1], dtype=np.int64), name="one")
557+
.cst(np.array([2], dtype=np.int64), name="two")
558+
.vin("X", TensorProto.INT64, shape=[None, None])
559+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
560+
.rename("CX")
561+
.bring("CX", "one", "two", "one")
562+
.Slice()
563+
.vout(shape=[])
564+
.to_onnx()
565+
)
566+
ref = ReferenceEvaluator(onx)
567+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
568+
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)
569+
541570

542571
if __name__ == "__main__":
543-
TestLightApi().test_add()
544572
unittest.main(verbosity=2)

‎onnx_array_api/light_api/_op_var.py

Copy file name to clipboardExpand all lines: onnx_array_api/light_api/_op_var.py
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,13 @@ def Selu(
314314
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
315315
return self.make_node("Shrink", self, bias=bias, lambd=lambd)
316316

317+
def Slice(
318+
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
319+
) -> "Var":
320+
if steps is None:
321+
return self.make_node("Slice", self, starts, ends, axes)
322+
return self.make_node("Slice", self, starts, ends, axes, steps)
323+
317324
def Softmax(self, axis: int = -1) -> "Var":
318325
return self.make_node("Softmax", self, axis=axis)
319326

‎onnx_array_api/light_api/model.py

Copy file name to clipboardExpand all lines: onnx_array_api/light_api/model.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def to_onnx(self) -> GRAPH_PROTO:
406406
return graph
407407
model = make_model(graph, opset_imports=opsets)
408408
if self.ir_version:
409-
model.ir_version = ir_version
409+
model.ir_version = self.ir_version
410410
if not is_windows() or not is_azure():
411411
# check_model fails sometimes on Windows
412412
check_model(model)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.