Skip to content

Commit a6bd55f

Browse files
authored
Implement all_nodes() and subgraphs() as convenience methods (#76)
Suggested by @xiaoyu-work, this PR created convenience methods all_nodes() and subgraphs() for easier access to the nested nodes and subgraphs. They are created so that it is easier to discover the methods. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 4a816f3 commit a6bd55f

File tree

2 files changed

+113
-6
lines changed

2 files changed

+113
-6
lines changed

src/onnx_ir/_core.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,6 +2366,28 @@ def num_nodes(self) -> int:
23662366
# NOTE: This is a method specific to Graph, not required by the protocol unless proven
23672367
return len(self)
23682368

2369+
def all_nodes(self) -> Iterator[Node]:
2370+
"""Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
2371+
2372+
This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
2373+
Consider using
2374+
:class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
2375+
traversals on nodes.
2376+
"""
2377+
# NOTE: This is a method specific to Graph, not required by the protocol unless proven
2378+
return onnx_ir.traversal.RecursiveGraphIterator(self)
2379+
2380+
def subgraphs(self) -> Iterator[Graph]:
2381+
"""Get all subgraphs in the graph in O(#nodes + #attributes) time."""
2382+
seen_graphs: set[Graph] = set()
2383+
for node in onnx_ir.traversal.RecursiveGraphIterator(self):
2384+
graph = node.graph
2385+
if graph is self:
2386+
continue
2387+
if graph is not None and graph not in seen_graphs:
2388+
seen_graphs.add(graph)
2389+
yield graph
2390+
23692391
# Mutation methods
23702392
def append(self, node: Node, /) -> None:
23712393
"""Append a node to the graph in O(1) time.
@@ -2871,7 +2893,7 @@ def graphs(self) -> Iterable[Graph]:
28712893
"""Get all graphs and subgraphs in the model.
28722894
28732895
This is a convenience method to traverse the model. Consider using
2874-
`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
2896+
:class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
28752897
traversals on nodes.
28762898
"""
28772899
# NOTE(justinchuby): Given
@@ -2880,11 +2902,8 @@ def graphs(self) -> Iterable[Graph]:
28802902
# (3) Users familiar with onnxruntime optimization tools expect this method
28812903
# I created this method as a core method instead of an iterator in
28822904
# `traversal.py`.
2883-
seen_graphs: set[Graph] = set()
2884-
for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
2885-
if node.graph is not None and node.graph not in seen_graphs:
2886-
seen_graphs.add(node.graph)
2887-
yield node.graph
2905+
yield self.graph
2906+
yield from self.graph.subgraphs()
28882907

28892908

28902909
class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
@@ -3018,6 +3037,28 @@ def meta(self) -> _metadata.MetadataStore:
30183037
def metadata_props(self) -> dict[str, str]:
30193038
return self._graph.metadata_props
30203039

3040+
def all_nodes(self) -> Iterator[Node]:
3041+
"""Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
3042+
3043+
This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
3044+
Consider using
3045+
:class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
3046+
traversals on nodes.
3047+
"""
3048+
# NOTE: This is a method specific to Graph, not required by the protocol unless proven
3049+
return onnx_ir.traversal.RecursiveGraphIterator(self)
3050+
3051+
def subgraphs(self) -> Iterator[Graph]:
3052+
"""Get all subgraphs in the function in O(#nodes + #attributes) time."""
3053+
seen_graphs: set[Graph] = set()
3054+
for node in onnx_ir.traversal.RecursiveGraphIterator(self):
3055+
graph = node.graph
3056+
if graph is self._graph:
3057+
continue
3058+
if graph is not None and graph not in seen_graphs:
3059+
seen_graphs.add(graph)
3060+
yield graph
3061+
30213062
# Mutation methods
30223063
def append(self, node: Node, /) -> None:
30233064
"""Append a node to the function in O(1) time."""

src/onnx_ir/_core_test.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,72 @@ def test_topological_sort_subgraph(self):
12761276
("d", "c", "b", "a", ">", "if"),
12771277
)
12781278

1279+
def test_all_nodes_returns_all_nodes(self):
1280+
# Create a graph with a subgraph
1281+
v0 = _core.Value(name="v0")
1282+
v1 = _core.Value(name="v1")
1283+
node0 = _core.Node("", "A", inputs=(v0,), num_outputs=1)
1284+
node1 = _core.Node("", "B", inputs=(v1,), num_outputs=1)
1285+
sub_node = _core.Node(
1286+
"", "Sub", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1
1287+
)
1288+
subgraph = _core.Graph(
1289+
inputs=(), outputs=(sub_node.outputs[0],), nodes=(sub_node,), name="subgraph"
1290+
)
1291+
main_node = _core.Node(
1292+
"",
1293+
"If",
1294+
inputs=(node0.outputs[0],),
1295+
attributes=[ir.AttrGraph("then_branch", subgraph)],
1296+
)
1297+
graph = _core.Graph(
1298+
inputs=(v0, v1),
1299+
outputs=(main_node.outputs[0],),
1300+
nodes=(node0, node1, main_node),
1301+
name="main_graph",
1302+
)
1303+
all_nodes = list(graph.all_nodes())
1304+
# Should include node0, node1, main_node, and sub_node
1305+
self.assertIn(node0, all_nodes)
1306+
self.assertIn(node1, all_nodes)
1307+
self.assertIn(main_node, all_nodes)
1308+
self.assertIn(sub_node, all_nodes)
1309+
self.assertEqual(len(all_nodes), 4)
1310+
1311+
def test_subgraphs_returns_all_subgraphs(self):
1312+
# Create a graph with two subgraphs
1313+
v0 = _core.Value(name="v0")
1314+
v1 = _core.Value(name="v1")
1315+
node0 = _core.Node("", "A", inputs=(v0,), num_outputs=1)
1316+
node1 = _core.Node("", "B", inputs=(v1,), num_outputs=1)
1317+
sub_node1 = _core.Node("", "Sub1", inputs=(node0.outputs[0],), num_outputs=1)
1318+
sub_node2 = _core.Node("", "Sub2", inputs=(node1.outputs[0],), num_outputs=1)
1319+
subgraph1 = _core.Graph(
1320+
inputs=(), outputs=(sub_node1.outputs[0],), nodes=(sub_node1,), name="subgraph1"
1321+
)
1322+
subgraph2 = _core.Graph(
1323+
inputs=(), outputs=(sub_node2.outputs[0],), nodes=(sub_node2,), name="subgraph2"
1324+
)
1325+
main_node = _core.Node(
1326+
"",
1327+
"If",
1328+
inputs=(node0.outputs[0],),
1329+
attributes=[
1330+
ir.AttrGraph("then_branch", subgraph1),
1331+
ir.AttrGraph("else_branch", subgraph2),
1332+
],
1333+
)
1334+
graph = _core.Graph(
1335+
inputs=(v0, v1),
1336+
outputs=(main_node.outputs[0],),
1337+
nodes=(node0, node1, main_node),
1338+
name="main_graph",
1339+
)
1340+
subgraphs = list(graph.subgraphs())
1341+
self.assertIn(subgraph1, subgraphs)
1342+
self.assertIn(subgraph2, subgraphs)
1343+
self.assertEqual(len(subgraphs), 2)
1344+
12791345

12801346
class GraphContainersTest(unittest.TestCase):
12811347
"""Test containers for input, output and initializers of a graph."""

0 commit comments

Comments
 (0)