Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit dcc2ddd

Browse files
authored
Add discrepancies when comparing the execution of two models (#79)
* update requirements * add discrepancies figures * fix command line * doc
1 parent a906010 commit dcc2ddd
Copy full SHA for dcc2ddd

File tree

4 files changed

+81
-9
lines changed
Filter options

4 files changed

+81
-9
lines changed

‎CHANGELOGS.rst

Copy file name to clipboardExpand all lines: CHANGELOGS.rst
+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Change Logs
55
+++++
66

77
* :pr:`77`: supports ConcatOfShape and Slice with the light API
8-
* :pr:`76`: add a mode to compare models without execution
8+
* :pr:`76`, :pr:`79`: add a mode to compare models without execution
99
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
1010
* :pr:`71`: adds tools to compare two onnx graphs
1111
* :pr:`61`: adds function to plot onnx model as graphs

‎_unittests/ut_reference/test_evaluator_yield.py

Copy file name to clipboardExpand all lines: _unittests/ut_reference/test_evaluator_yield.py
+25
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,31 @@ def test_compare_execution(self):
462462
self.assertIn("CAAA Constant", text)
463463
self.assertEqual(len(align), 5)
464464

465+
def test_compare_execution_discrepancies(self):
    """
    Compares two small models while keeping the intermediate tensors
    and checks the report contains the absolute and relative errors.
    """
    # First model: holds an extra ``Add`` node the second model lacks.
    model_with_add = parse_model(
        """
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
    )
    # Second model: same graph without the ``Add`` node.
    model_without_add = parse_model(
        """
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
z = Mul(x, x)
}"""
    )
    # keep_tensor=True keeps the values so that to_str can append
    # the discrepancy columns.
    results = compare_onnx_execution(
        model_with_add, model_without_add, keep_tensor=True
    )
    res1, res2, align, dc = results
    text = dc.to_str(res1, res2, align)
    print(text)
    for fragment in ("CAAA Constant", "| a=", " r="):
        self.assertIn(fragment, text)
489+
465490
def test_no_execution(self):
466491
model = make_model(
467492
make_graph(

‎onnx_array_api/_command_lines_parser.py

Copy file name to clipboardExpand all lines: onnx_array_api/_command_lines_parser.py
+12-2
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,15 @@ def get_parser_compare() -> ArgumentParser:
106106
parser.add_argument(
107107
"-c",
108108
"--column-size",
109-
default=50,
109+
default=60,
110110
help="column size when displaying the results",
111111
)
112+
parser.add_argument(
113+
"-d",
114+
"--discrepancies",
115+
default=0,
116+
help="show precise discrepancies when mode is execution",
117+
)
112118
return parser
113119

114120

@@ -120,7 +126,11 @@ def _cmd_compare(argv: List[Any]):
120126
onx1 = onnx.load(args.model1)
121127
onx2 = onnx.load(args.model2)
122128
res1, res2, align, dc = compare_onnx_execution(
123-
onx1, onx2, verbose=args.verbose, mode=args.mode
129+
onx1,
130+
onx2,
131+
verbose=args.verbose,
132+
mode=args.mode,
133+
keep_tensor=args.discrepancies in (1, "1", "True", True),
124134
)
125135
text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
126136
print(text)

‎onnx_array_api/reference/evaluator_yield.py

Copy file name to clipboardExpand all lines: onnx_array_api/reference/evaluator_yield.py
+43-6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class ResultExecution:
5757
summary: str
5858
op_type: str
5959
name: str
60+
value: Optional[Any] = None
6061

6162
def __len__(self) -> int:
6263
return 6
@@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
122123
else:
123124
value2 = value.flatten().astype(np.float64)
124125
value4 = value2.reshape((4, -1)).sum(axis=1)
125-
value4i = value4.astype(np.int64) % modulo
126-
s = "".join([chr(65 + i) for i in value4i])
127-
return s
126+
value4 = np.where(np.abs(value4) < 1e10, value4, np.nan)
127+
s = []
128+
for v in value4:
129+
s.append("?" if np.isnan(v) else (chr(65 + int(v) % modulo)))
130+
return "".join(s)
128131

129132

130133
class YieldEvaluator:
@@ -228,6 +231,7 @@ def enumerate_summarized(
228231
output_names: Optional[List[str]] = None,
229232
feed_inputs: Optional[Dict[str, Any]] = None,
230233
raise_exc: bool = True,
234+
keep_tensor: bool = False,
231235
) -> Iterator[ResultExecution]:
232236
"""
233237
Executes the onnx model and enumerate intermediate results without their names.
@@ -236,17 +240,40 @@ def enumerate_summarized(
236240
:param feed_inputs: dictionary `{ input name: input value }`
237241
:param raise_exc: raises an exception if the execution fails or stop
238242
where it is
243+
:param keep_tensor: keeps the tensor in order to compute precise distances
239244
:return: iterator on ResultExecution
240245
"""
241246
for kind, name, value, op_type in self.enumerate_results(
242247
output_names, feed_inputs, raise_exc=raise_exc
243248
):
244249
summary = make_summary(value)
245250
yield ResultExecution(
246-
kind, value.dtype, value.shape, summary, op_type, name
251+
kind,
252+
value.dtype,
253+
value.shape,
254+
summary,
255+
op_type,
256+
name,
257+
value=value if keep_tensor else None,
247258
)
248259

249260

261+
def discrepancies(
    expected: np.ndarray, value: np.ndarray, eps: float = 1e-7
) -> Dict[str, float]:
    """
    Computes the maximum absolute error and the maximum relative error
    between two arrays holding the same number of elements.

    Both arrays are flattened and cast to float32 before the comparison,
    only the number of elements has to match, not the shapes.

    :param expected: reference array
    :param value: array compared to the reference
    :param eps: added to ``abs(expected)`` so that the relative error
        does not divide by zero
    :return: dictionary with two keys, ``aerr`` (maximum absolute error)
        and ``rerr`` (maximum relative error)
    """
    assert (
        expected.size == value.size
    ), f"Incompatible shapes v1.shape={expected.shape}, v2.shape={value.shape}"
    if expected.size == 0:
        # ``max`` has no identity for zero-size arrays and would raise
        # ValueError, two empty tensors have no discrepancy.
        return dict(aerr=0.0, rerr=0.0)
    expected = expected.ravel().astype(np.float32)
    value = value.ravel().astype(np.float32)
    diff = np.abs(expected - value)
    rel = diff / (np.abs(expected) + eps)
    return dict(aerr=float(diff.max()), rerr=float(rel.max()))
275+
276+
250277
class DistanceExecution:
251278
"""
252279
Computes a distance between two results.
@@ -403,6 +430,14 @@ def to_str(
403430
d = self.distance_pair(d1, d2)
404431
symbol = "=" if d == 0 else "~"
405432
line = f"{symbol} | {_align(str(d1), column_size)} | {_align(str(d2), column_size)}"
433+
if (
434+
d1.value is not None
435+
and d2.value is not None
436+
and d1.value.size == d2.value.size
437+
):
438+
disc = discrepancies(d1.value, d2.value)
439+
a, r = disc["aerr"], disc["rerr"]
440+
line += f" | a={a:.3f} r={r:.3f}"
406441
elif i == last[0]:
407442
d2 = s2[j]
408443
line = (
@@ -551,6 +586,7 @@ def compare_onnx_execution(
551586
verbose: int = 0,
552587
raise_exc: bool = True,
553588
mode: str = "execute",
589+
keep_tensor: bool = False,
554590
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
555591
"""
556592
Compares the execution of two onnx models.
@@ -566,6 +602,7 @@ def compare_onnx_execution(
566602
:param raise_exc: raise exception if the execution fails or stop at the error
567603
:param mode: the model should be executed but the function can be executed
568604
but the comparison may append on nodes only
605+
:param keep_tensor: keeps the tensor in order to compute a precise distance
569606
:return: four results, a sequence of results for the first model and the second model,
570607
the alignment between the two, DistanceExecution
571608
"""
@@ -589,15 +626,15 @@ def compare_onnx_execution(
589626
print("[compare_onnx_execution] execute first model")
590627
res1 = list(
591628
YieldEvaluator(model1).enumerate_summarized(
592-
None, feeds1, raise_exc=raise_exc
629+
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
593630
)
594631
)
595632
if verbose:
596633
print(f"[compare_onnx_execution] got {len(res1)} results")
597634
print("[compare_onnx_execution] execute second model")
598635
res2 = list(
599636
YieldEvaluator(model2).enumerate_summarized(
600-
None, feeds2, raise_exc=raise_exc
637+
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
601638
)
602639
)
603640
elif mode == "nodes":

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.