Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 7675869

Browse filesBrowse files
authored
Extend ExtendedReferenceEvaluator (#75)
* update requirements * add more operator to the reference evaluator * extend unit test copverage
1 parent a070da3 commit 7675869
Copy full SHA for 7675869

File tree

6 files changed

+222
-0
lines changed
Filter options

6 files changed

+222
-0
lines changed

‎CHANGELOGS.rst

Copy file name to clipboardExpand all lines: CHANGELOGS.rst
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
78
* :pr:`71`: adds tools to compare two onnx graphs
89
* :pr:`61`: adds function to plot onnx model as graphs
910
* :pr:`60`: supports translation of local functions

‎_unittests/ut_reference/test_reference_ops.py

Copy file name to clipboardExpand all lines: _unittests/ut_reference/test_reference_ops.py
+82Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,88 @@ def test_fused_matmul11(self):
5959
got = ref.run(None, {"X": a, "Y": a})
6060
self.assertEqualArray(a.T @ a.T, got[0])
6161

62+
def test_memcpy(self):
63+
model = make_model(
64+
make_graph(
65+
[
66+
make_node("MemcpyToHost", ["X"], ["Z"]),
67+
make_node("MemcpyFromHost", ["X"], ["Z"]),
68+
],
69+
"name",
70+
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
71+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
72+
),
73+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
74+
ir_version=9,
75+
)
76+
a = np.arange(4).reshape(-1, 2).astype(np.float32)
77+
ref = ExtendedReferenceEvaluator(model)
78+
got = ref.run(None, {"X": a})
79+
self.assertEqualArray(a, got[0])
80+
81+
def test_quick_gelu(self):
82+
from onnxruntime import InferenceSession
83+
84+
for alpha in [0.0, 2.0]:
85+
model = make_model(
86+
make_graph(
87+
[
88+
make_node(
89+
"QuickGelu",
90+
["X"],
91+
["Z"],
92+
domain="com.microsoft",
93+
alpha=alpha,
94+
)
95+
],
96+
"name",
97+
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
98+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
99+
),
100+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
101+
ir_version=9,
102+
)
103+
sess = InferenceSession(
104+
model.SerializeToString(), providers=["CPUExecutionProvider"]
105+
)
106+
a = np.arange(4).reshape(-1, 2).astype(np.float32)
107+
expected = sess.run(None, {"X": a})
108+
ref = ExtendedReferenceEvaluator(model)
109+
got = ref.run(None, {"X": a})
110+
self.assertEqualArray(expected[0], got[0])
111+
112+
def test_scatter_elements(self):
113+
model = make_model(
114+
make_graph(
115+
[
116+
make_node(
117+
"ScatterElements",
118+
["data", "indices", "updates"],
119+
["Z"],
120+
axis=3,
121+
reduction="add",
122+
)
123+
],
124+
"name",
125+
[
126+
make_tensor_value_info("data", TensorProto.FLOAT, None),
127+
make_tensor_value_info("indices", TensorProto.INT64, None),
128+
make_tensor_value_info("updates", TensorProto.FLOAT, None),
129+
],
130+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
131+
),
132+
opset_imports=[make_opsetid("", 18)],
133+
)
134+
data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
135+
indices = np.array([[[[0]]]], dtype=np.int64)
136+
updates = np.array([[[[1]]]], dtype=np.float32)
137+
y = np.array(
138+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
139+
).reshape((2, 2, 2, 2))
140+
ref = ExtendedReferenceEvaluator(model)
141+
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
142+
self.assertEqualArray(y, got[0])
143+
62144

63145
if __name__ == "__main__":
64146
unittest.main(verbosity=2)

‎onnx_array_api/reference/evaluator.py

Copy file name to clipboardExpand all lines: onnx_array_api/reference/evaluator.py
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from .ops.op_concat import Concat
99
from .ops.op_constant_of_shape import ConstantOfShape
1010
from .ops.op_fused_matmul import FusedMatMul
11+
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
12+
from .ops.op_quick_gelu import QuickGelu
13+
from .ops.op_scatter_elements import ScatterElements
1114

1215

1316
logger = getLogger("onnx-array-api-eval")
@@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
3437
CastLike_19,
3538
ConstantOfShape,
3639
FusedMatMul,
40+
MemcpyFromHost,
41+
MemcpyToHost,
42+
QuickGelu,
43+
ScatterElements,
3744
]
3845

3946
@staticmethod
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from onnx.reference.op_run import OpRun
2+
3+
4+
class MemcpyFromHost(OpRun):
5+
def _run(self, x):
6+
return (x,)
7+
8+
9+
class MemcpyToHost(OpRun):
10+
def _run(self, x):
11+
return (x,)
+23Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from onnx.reference.op_run import OpRun
3+
4+
5+
def sigmoid(x): # type: ignore
6+
if x > 0:
7+
return 1 / (1 + np.exp(-x))
8+
return np.exp(x) / (1 + np.exp(x))
9+
10+
11+
class QuickGelu(OpRun):
12+
op_domain = "com.microsoft"
13+
14+
def __init__(self, onnx_node, run_params): # type: ignore
15+
OpRun.__init__(self, onnx_node, run_params)
16+
self.vf = np.vectorize(sigmoid)
17+
18+
def _run(self, X, alpha=1.0):
19+
if len(X.shape) == 0:
20+
return ((X * sigmoid(X * alpha)).astype(X.dtype),)
21+
if X.size == 0:
22+
return (X,)
23+
return ((X * self.vf(X * alpha)).astype(X.dtype),)
+98Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
3+
from onnx.reference.op_run import OpRun
4+
5+
6+
def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore
7+
if reduction == "add":
8+
9+
def f(x, y):
10+
return x + y
11+
12+
elif reduction == "min":
13+
14+
def f(x, y):
15+
return min(x, y)
16+
17+
elif reduction == "max":
18+
19+
def f(x, y):
20+
return max(x, y)
21+
22+
else:
23+
24+
def f(x, y):
25+
return y
26+
27+
if axis < 0:
28+
axis = data.ndim + axis
29+
30+
if len(data.shape) == 1 and axis == 0:
31+
scattered = np.copy(data)
32+
for pos, up in zip(indices, updates):
33+
scattered[pos] = f(scattered[pos], up)
34+
return scattered
35+
36+
if len(indices.shape) == 2:
37+
scattered = np.copy(data)
38+
if axis == 0:
39+
for i in range(indices.shape[0]):
40+
for j in range(indices.shape[1]):
41+
scattered[indices[i, j], j] = f(
42+
scattered[indices[i, j], j], updates[i, j]
43+
)
44+
else:
45+
for i in range(indices.shape[0]):
46+
for j in range(indices.shape[1]):
47+
scattered[i, indices[i, j]] = f(
48+
scattered[i, indices[i, j]], updates[i, j]
49+
)
50+
return scattered
51+
52+
if len(indices.shape) == 3:
53+
scattered = np.copy(data)
54+
if axis == 0:
55+
for i in range(indices.shape[0]):
56+
for j in range(indices.shape[1]):
57+
for k in range(indices.shape[2]):
58+
scattered[indices[i, j, k], j, k] = f(
59+
scattered[indices[i, j, k], j, k], updates[i, j, k]
60+
)
61+
elif axis == 1:
62+
for i in range(indices.shape[0]):
63+
for j in range(indices.shape[1]):
64+
for k in range(indices.shape[2]):
65+
scattered[i, indices[i, j, k], k] = f(
66+
scattered[i, indices[i, j, k], k], updates[i, j, k]
67+
)
68+
elif axis == 2:
69+
for i in range(indices.shape[0]):
70+
for j in range(indices.shape[1]):
71+
for k in range(indices.shape[2]):
72+
scattered[i, j, indices[i, j, k]] = f(
73+
scattered[i, j, indices[i, j, k]], updates[i, j, k]
74+
)
75+
return scattered
76+
77+
if len(indices.shape) == 4:
78+
scattered = np.copy(data)
79+
if axis == 3:
80+
for a in range(indices.shape[0]):
81+
for i in range(indices.shape[1]):
82+
for j in range(indices.shape[2]):
83+
for k in range(indices.shape[3]):
84+
scattered[a, i, j, indices[a, i, j, k]] = f(
85+
scattered[a, i, j, indices[a, i, j, k]],
86+
updates[a, i, j, k],
87+
)
88+
return scattered
89+
90+
raise RuntimeError(
91+
f"Not implemented for indices.shape={indices.shape} and axis={axis}"
92+
)
93+
94+
95+
class ScatterElements(OpRun):
96+
def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore
97+
res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
98+
return (res,)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.