Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6ed1d1c

Browse filesBrowse files
authored
Fix bugs in remove_identity (#70)
* update requirements * fix bugs in remove_identity nodes
1 parent ad22d16 commit 6ed1d1c
Copy full SHA for 6ed1d1c

File tree

2 files changed

+37
-8
lines changed
Filter options

2 files changed

+37
-8
lines changed

‎onnx_array_api/graph_api/graph_builder.py

Copy file name to clipboardExpand all lines: onnx_array_api/graph_api/graph_builder.py
+35-6Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -836,30 +836,57 @@ def remove_identity_nodes(self):
836836
"""
837837
Removes identity nodes.
838838
"""
839-
# f<irst pass: detect replacements
839+
# first pass: detect replacements
840840
new_nodes = []
841841
input_names = set(i.name for i in self.inputs)
842842
output_names = set(i.name for i in self.outputs)
843843
replacements = {}
844+
replacements_rev = {}
844845
for node in self.nodes:
845846
if node.op_type != "Identity":
846847
new_nodes.append(node)
847848
continue
848849

849850
if node.output[0] not in output_names:
850851
old_name, new_name = node.output[0], node.input[0]
851-
elif node.input[0] not in input_names:
852+
elif (
853+
node.input[0] not in input_names
854+
and node.input[0] not in output_names
855+
and node.input[0] not in replacements
856+
):
852857
old_name, new_name = node.input[0], node.output[0]
853858
else:
854859
new_nodes.append(node)
855860
continue
856861

857862
# the new name can be set for replacements as well
858-
assert old_name not in replacements
859863
if new_name in replacements:
860864
new_name = replacements[new_name]
861-
assert new_name not in replacements
865+
assert new_name not in replacements, (
866+
f"Name {old_name!r} still in {replacements}, node.op_type={node.op_type!r}, "
867+
f"node.input={node.input}, node.output={node.output}, "
868+
f"input_names={input_names}, output_names={output_names}"
869+
)
870+
if old_name in replacements_rev:
871+
old_old_name = replacements_rev[old_name]
872+
replacements[old_old_name] = new_name
873+
replacements_rev[new_name] = old_old_name
874+
if old_name in replacements:
875+
replacements[replacements[old_name]] = new_name
876+
assert new_name not in replacements, (
877+
f"Name {old_name!r} still in {replacements}, node.op_type={node.op_type!r}, "
878+
f"node.input={node.input}, node.output={node.output}, "
879+
f"input_names={input_names}, output_names={output_names}"
880+
)
862881
replacements[old_name] = new_name
882+
replacements_rev[new_name] = old_name
883+
884+
# verification
885+
for k, v in replacements.items():
886+
assert v not in replacements, (
887+
f"replacement {k}->{v} is not possible because of "
888+
f"{v}->{replacements[v]}, old_name={old_name!r}, new_name={new_name!r}"
889+
)
863890

864891
# second pass: replacements in initializer
865892
for k, v in replacements.items():
@@ -876,10 +903,12 @@ def remove_identity_nodes(self):
876903
repo = {o for o in node.output if o in replacements}
877904
repi = {o for o in node.input if o in replacements}
878905
if repi or repo:
906+
new_inputs = [replacements.get(i, i) for i in node.input]
907+
new_outputs = [replacements.get(i, i) for i in node.output]
879908
new_node = oh.make_node(
880909
node.op_type,
881-
[replacements.get(i, i) for i in node.input],
882-
[replacements.get(i, i) for i in node.output],
910+
new_inputs,
911+
new_outputs,
883912
domain=node.domain,
884913
name=node.name,
885914
)

‎pyproject.toml

Copy file name to clipboardExpand all lines: pyproject.toml
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ exclude = [
1111
# Same as Black.
1212
line-length = 88
1313

14-
[tool.ruff.mccabe]
14+
[tool.ruff.lint.mccabe]
1515
# Unlike Flake8, default to a complexity level of 10.
1616
max-complexity = 10
1717

18-
[tool.ruff.per-file-ignores]
18+
[tool.ruff.lint.per-file-ignores]
1919
"_doc/examples/plot_first_example.py" = ["E402", "F811"]
2020
"_doc/examples/plot_onnxruntime.py" = ["E402", "F811"]
2121
"onnx_array_api/array_api/_onnx_common.py" = ["F821"]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.