diff --git a/onnx_array_api/graph_api/graph_builder.py b/onnx_array_api/graph_api/graph_builder.py index f238eee..c9c2059 100644 --- a/onnx_array_api/graph_api/graph_builder.py +++ b/onnx_array_api/graph_api/graph_builder.py @@ -836,11 +836,12 @@ def remove_identity_nodes(self): """ Removes identity nodes. """ - # f{v} is not possible because of " + f"{v}->{replacements[v]}, old_name={old_name!r}, new_name={new_name!r}" + ) # second pass: replacements in initializer for k, v in replacements.items(): @@ -876,10 +903,12 @@ def remove_identity_nodes(self): repo = {o for o in node.output if o in replacements} repi = {o for o in node.input if o in replacements} if repi or repo: + new_inputs = [replacements.get(i, i) for i in node.input] + new_outputs = [replacements.get(i, i) for i in node.output] new_node = oh.make_node( node.op_type, - [replacements.get(i, i) for i in node.input], - [replacements.get(i, i) for i in node.output], + new_inputs, + new_outputs, domain=node.domain, name=node.name, ) diff --git a/pyproject.toml b/pyproject.toml index 0b0e71d..525b648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,11 @@ exclude = [ # Same as Black. line-length = 88 -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "_doc/examples/plot_first_example.py" = ["E402", "F811"] "_doc/examples/plot_onnxruntime.py" = ["E402", "F811"] "onnx_array_api/array_api/_onnx_common.py" = ["F821"]