Skip to content

Commit 1d0dc80

Browse files
authored
External data callback and ir.save fixes (#85)
This pull request adds a new feature to the ONNX IR library: an optional callback parameter on several methods. The callback lets users debug or log information about tensors as they are saved to external data files, or drive a progress bar. Minor improvements to code organization and comments are also included. ![image](https://github.com/user-attachments/assets/0ced1a65-5235-469d-8053-2fa93116986a) --------- Signed-off-by: Justin Chu <[email protected]>
1 parent b802852 commit 1d0dc80

File tree

2 files changed

+82
-6
lines changed

2 files changed

+82
-6
lines changed

src/onnx_ir/_io.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
__all__ = ["load", "save"]
88

99
import os
10+
from typing import Callable
1011

1112
import onnx
1213

13-
from onnx_ir import _core, serde
14+
from onnx_ir import _core, _protocols, serde
1415
from onnx_ir import external_data as _external_data
1516
from onnx_ir._polyfill import zip
1617

@@ -43,6 +44,8 @@ def save(
4344
format: str | None = None,
4445
external_data: str | os.PathLike | None = None,
4546
size_threshold_bytes: int = 256,
47+
callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
48+
| None = None,
4649
) -> None:
4750
"""Save an ONNX model to a file.
4851
@@ -52,6 +55,30 @@ def save(
5255
to load the newly saved model, or provide a different external data path that
5356
is not currently referenced by any tensors in the model.
5457
58+
.. tip::
59+
60+
A simple progress bar can be implemented by passing a callback function as follows::
61+
62+
import onnx_ir as ir
63+
import tqdm
64+
65+
with tqdm.tqdm() as pbar:
66+
total_set = False
67+
68+
def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
69+
nonlocal total_set
70+
if not total_set:
71+
pbar.total = metadata.total
72+
total_set = True
73+
74+
pbar.update()
75+
pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
76+
77+
ir.save(
78+
...,
79+
callback=callback,
80+
)
81+
5582
Args:
5683
model: The model to save.
5784
path: The path to save the model to. E.g. "model.onnx".
@@ -65,6 +92,8 @@ def save(
6592
it will be serialized in the ONNX Proto message.
6693
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
6794
Effective only when ``external_data`` is set.
95+
callback: A callback function that is called for each tensor saved to external data;
96+
intended for debugging or logging purposes (e.g. driving a progress bar).
6897
6998
Raises:
7099
ValueError: If the external data path is an absolute path.
@@ -77,12 +106,19 @@ def save(
77106
base_dir = os.path.dirname(path)
78107

79108
# Store the original initializer values so they can be restored if modify_model=False
80-
initializer_values = tuple(model.graph.initializers.values())
109+
initializer_values: list[_core.Value] = []
110+
for graph in model.graphs():
111+
# Collect from all subgraphs as well
112+
initializer_values.extend(graph.initializers.values())
81113
tensors = [v.const_value for v in initializer_values]
82114

83115
try:
84116
model = _external_data.unload_from_model(
85-
model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
117+
model,
118+
base_dir,
119+
external_data,
120+
size_threshold_bytes=size_threshold_bytes,
121+
callback=callback,
86122
)
87123
proto = serde.serialize_model(model)
88124
onnx.save(proto, path, format=format)

src/onnx_ir/external_data.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44

55
from __future__ import annotations
66

7+
from typing import Callable
8+
79
__all__ = [
810
"set_base_dir",
911
"unload_from_model",
1012
"load_to_model",
1113
"convert_tensors_to_external",
1214
"convert_tensors_from_external",
15+
"CallbackInfo",
1316
]
1417

1518
import dataclasses
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
4851
length: int
4952

5053

54+
@dataclasses.dataclass
55+
class CallbackInfo:
56+
"""A class that shares information about a tensor that is to be saved as external data for callback functions.
57+
58+
Attributes:
59+
total: The total number of tensors to save.
60+
index: The zero-based index of the tensor being saved.
61+
offset: The offset of the tensor in the external data file.
62+
"""
63+
64+
total: int
65+
index: int
66+
offset: int
67+
68+
5169
def _all_tensors(
5270
graph: _core.Graph | _core.GraphView, include_attributes: bool = False
5371
) -> Iterator[_protocols.TensorProtocol]:
@@ -157,19 +175,34 @@ def _write_external_data(
157175
tensors: Sequence[_protocols.TensorProtocol],
158176
external_data_infos: Sequence[_ExternalDataInfo],
159177
file_path: str | os.PathLike,
178+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
160179
) -> None:
161180
"""Write tensor data to an external file according to information stored in ExternalDataInfo objects.
162181
163182
Args:
164183
tensors: Tensors to be written as external data.
165184
external_data_infos: External data information stored for each tensor to be written as external data.
166185
file_path: Location to which external data is to be stored.
186+
callback: A callback function that is called for each tensor saved to external data;
187+
intended for debugging or logging purposes.
167188
"""
168-
assert len(tensors) == len(external_data_infos), (
189+
tensors_count = len(tensors)
190+
assert tensors_count == len(external_data_infos), (
169191
"Number of tensors and external data infos should match"
170192
)
171193
with open(file_path, "wb") as data_file:
172-
for tensor, tensor_info in zip(tensors, external_data_infos, strict=True):
194+
for i, (tensor, tensor_info) in enumerate(
195+
zip(tensors, external_data_infos, strict=True)
196+
):
197+
if callback is not None:
198+
callback(
199+
tensor,
200+
CallbackInfo(
201+
total=tensors_count,
202+
index=i,
203+
offset=tensor_info.offset,
204+
),
205+
)
173206
current_offset = tensor_info.offset
174207
assert tensor is not None
175208
raw_data = tensor.tobytes()
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
228261
tensors: Sequence[_protocols.TensorProtocol],
229262
base_dir: str | os.PathLike,
230263
relative_path: str | os.PathLike,
264+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
231265
) -> list[_core.ExternalTensor]:
232266
"""Convert a sequence of any TensorProtocol tensors to external tensors.
233267
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
238272
tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
239273
base_dir: Path of base directory.
240274
relative_path: Path to which external data is to be stored, relative to the ONNX file.
275+
callback: A callback function that is called for each tensor saved to external data;
276+
intended for debugging or logging purposes.
241277
242278
Returns:
243279
A list of external tensors derived from a list of input tensors. The order
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
285321
external_info = _compute_external_data_info(tensor, current_offset)
286322
external_data_infos.append(external_info)
287323
current_offset = external_info.offset + external_info.length
288-
_write_external_data(sorted_tensors, external_data_infos, path)
324+
_write_external_data(sorted_tensors, external_data_infos, path, callback=callback)
289325

290326
# Create external tensor objects
291327
external_tensors: list[_core.ExternalTensor] = [
@@ -336,6 +372,7 @@ def unload_from_model(
336372
relative_path: str | os.PathLike,
337373
*,
338374
size_threshold_bytes: int = 0,
375+
callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
339376
) -> _core.Model:
340377
"""Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
341378
@@ -356,6 +393,8 @@ def unload_from_model(
356393
relative_path: Path to which external data is to be stored, relative to the ONNX file.
357394
E.g. "model.data"
358395
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
396+
callback: A callback function that is called for each tensor saved to external data;
397+
intended for debugging or logging purposes.
359398
360399
Returns:
361400
An ir.Model with all initializer data equal or above ``size_threshold_bytes``
@@ -384,6 +423,7 @@ def unload_from_model(
384423
[v.const_value for v in initializers_to_become_external], # type: ignore[misc]
385424
base_dir=base_dir,
386425
relative_path=relative_path,
426+
callback=callback,
387427
)
388428

389429
# Replace the initializer values with external tensors and save the model

0 commit comments

Comments
 (0)