Skip to content

Commit b802852

Browse files
authored
Update from_onnx_text method (#80)
- Change the initializers parameter to take a list of tensors. - Add the missing tests and expose the `to_onnx_text` method in the root namespace. - Update `create_value_mapping` to include values from subgraphs. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent a6bd55f commit b802852

File tree

4 files changed

+100
-12
lines changed

4 files changed

+100
-12
lines changed

src/onnx_ir/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
"from_proto",
7474
"from_onnx_text",
7575
"to_proto",
76+
"to_onnx_text",
7677
# Convenience constructors
7778
"tensor",
7879
"node",
@@ -149,7 +150,7 @@
149150
TypeProtocol,
150151
ValueProtocol,
151152
)
152-
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
153+
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_onnx_text, to_proto
153154

154155
DEBUG = False
155156

src/onnx_ir/_convenience/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import onnx
2323

24-
from onnx_ir import _core, _enums, _protocols, serde
24+
from onnx_ir import _core, _enums, _protocols, serde, traversal
2525

2626
SupportedAttrTypes = Union[
2727
str,
@@ -313,7 +313,9 @@ def replace_all_uses_with(
313313
def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
314314
"""Return a dictionary mapping names to values in the graph.
315315
316-
The mapping does not include values from subgraphs.
316+
The mapping includes values from subgraphs. Duplicated names are omitted,
317+
and the first value with that name is returned. Values with empty names
318+
are excluded from the mapping.
317319
318320
Args:
319321
graph: The graph to extract the mapping from.
@@ -327,11 +329,23 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
327329
for input in graph.inputs:
328330
if not input.name:
329331
continue
332+
if input.name in values:
333+
continue
330334
values[input.name] = input
331-
for node in graph:
335+
for node in traversal.RecursiveGraphIterator(graph):
336+
for value in node.inputs:
337+
if not value:
338+
continue
339+
if not value.name:
340+
continue
341+
if value.name in values:
342+
continue
343+
values[value.name] = value
332344
for value in node.outputs:
333345
if not value.name:
334346
continue
347+
if value.name in values:
348+
continue
335349
values[value.name] = value
336350
return values
337351

src/onnx_ir/serde.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
import collections
6464
import logging
6565
import os
66-
from collections.abc import Mapping, Sequence
66+
from collections.abc import Iterable, Mapping, Sequence
6767
from typing import Any, Callable
6868

6969
import numpy as np
@@ -194,30 +194,36 @@ def from_proto(proto: object) -> object:
194194
def from_onnx_text(
195195
model_text: str,
196196
/,
197-
with_initializers: Mapping[str, _protocols.TensorProtocol] | None = None,
197+
initializers: Iterable[_protocols.TensorProtocol] | None = None,
198198
) -> _core.Model:
199199
"""Convert the ONNX textual representation to an IR model.
200200
201201
Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
202202
203203
Args:
204204
model_text: The ONNX textual representation of the model.
205-
with_initializers: A mapping of initializer names to tensors. If provided, these tensors
205+
initializers: Tensors to be added as initializers. If provided, these tensors
206206
will be added to the model as initializers. If a name does not exist in the model,
207207
a ValueError will be raised.
208208
209209
Returns:
210210
The IR model corresponding to the ONNX textual representation.
211211
212212
Raises:
213-
ValueError: If a name in `with_initializers` does not exist in the model.
213+
ValueError: If a tensor name in `initializers` does not match any value in the model.
214214
"""
215215
proto = onnx.parser.parse_model(model_text)
216216
model = deserialize_model(proto)
217217
values = _convenience.create_value_mapping(model.graph)
218-
if with_initializers:
218+
if initializers:
219219
# Add initializers to the model
220-
for name, tensor in with_initializers.items():
220+
for tensor in initializers:
221+
name = tensor.name
222+
if not name:
223+
raise ValueError(
224+
"Initializer tensor must have a name. "
225+
f"Please provide a name for the initializer: {tensor}"
226+
)
221227
if name not in values:
222228
raise ValueError(f"Value '{name}' does not exist in model.")
223229
initializer = values[name]

src/onnx_ir/serde_test.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_to_proto(self, _: str, ir_object):
5757
def test_from_to_onnx_text(self):
5858
model_text = """\
5959
<
60-
ir_version: 7,
60+
ir_version: 10,
6161
opset_import: ["" : 17]
6262
>
6363
agraph (float[1,4,512,512] input_x, float[1,4,512,64] input_y) => (float[4,512,512] reshape_x) {
@@ -67,12 +67,79 @@ def test_from_to_onnx_text(self):
6767
self.maxDiff = None
6868
model = serde.from_onnx_text(model_text)
6969
self.assertIsInstance(model, ir.Model)
70-
self.assertEqual(model.ir_version, 7)
70+
self.assertEqual(model.ir_version, 10)
7171
self.assertEqual(len(model.graph.inputs), 2)
7272
self.assertEqual(len(model.graph.outputs), 1)
7373
onnx_text_roundtrip = serde.to_onnx_text(model)
7474
self.assertEqual(model_text.strip(), onnx_text_roundtrip.strip())
7575

76+
def test_from_to_onnx_text_with_initializers(self):
77+
model_text = """\
78+
<
79+
ir_version: 10,
80+
opset_import: ["" : 17]
81+
>
82+
agraph (float[1] input_x, float[2] input_y) => (float[2] result) {
83+
[node_1] add = Add (input_x, input_y)
84+
[node_2] result = Add (add, initializer_z)
85+
}"""
86+
self.maxDiff = None
87+
array = np.array([1.0, 2.0], dtype=np.float32)
88+
init_array = np.array([3.0, 4.0], dtype=np.float32)
89+
model = serde.from_onnx_text(
90+
model_text,
91+
initializers=[
92+
ir.tensor(init_array, name="initializer_z"),
93+
ir.tensor(array, name="input_y"),
94+
],
95+
)
96+
np.testing.assert_array_equal(model.graph.inputs[1].const_value.numpy(), array)
97+
np.testing.assert_array_equal(
98+
model.graph.initializers["initializer_z"].const_value.numpy(), init_array
99+
)
100+
expected_text = """\
101+
<
102+
ir_version: 10,
103+
opset_import: ["" : 17]
104+
>
105+
agraph (float[1] input_x, float[2] input_y) => (float[2] result)
106+
<float[2] initializer_z = {3,4}, float[2] input_y = {1,2}>
107+
{
108+
[node_1] add = Add (input_x, input_y)
109+
[node_2] result = Add (add, initializer_z)
110+
}"""
111+
onnx_text_roundtrip = serde.to_onnx_text(model)
112+
stripped_lines = [line.rstrip() for line in onnx_text_roundtrip.splitlines()]
113+
result = "\n".join(stripped_lines)
114+
self.assertEqual(result, expected_text)
115+
116+
def test_to_onnx_text_excluding_initializers(self):
117+
model_text = """\
118+
<
119+
ir_version: 10,
120+
opset_import: ["" : 17]
121+
>
122+
agraph (float[1] input_x, float[2] input_y) => (float[2] result) {
123+
[node_name] result = Add (input_x, input_y)
124+
}"""
125+
self.maxDiff = None
126+
array = np.array([1.0, 2.0], dtype=np.float32)
127+
model = serde.from_onnx_text(
128+
model_text, initializers=[ir.tensor(array, name="input_y")]
129+
)
130+
onnx_text_without_initializers = serde.to_onnx_text(model, exclude_initializers=True)
131+
expected_text_without_initializers = """\
132+
<
133+
ir_version: 10,
134+
opset_import: ["" : 17]
135+
>
136+
agraph (float[1] input_x, float[2] input_y) => (float[2] result) {
137+
[node_name] result = Add (input_x, input_y)
138+
}"""
139+
self.assertEqual(
140+
onnx_text_without_initializers.strip(), expected_text_without_initializers
141+
)
142+
76143

77144
class TensorProtoTensorTest(unittest.TestCase):
78145
@parameterized.parameterized.expand(

0 commit comments

Comments
 (0)