4
4
5
5
from __future__ import annotations
6
6
7
+ from typing import Callable
8
+
7
9
__all__ = [
8
10
"set_base_dir" ,
9
11
"unload_from_model" ,
10
12
"load_to_model" ,
11
13
"convert_tensors_to_external" ,
12
14
"convert_tensors_from_external" ,
15
+ "CallbackInfo" ,
13
16
]
14
17
15
18
import dataclasses
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
48
51
length : int
49
52
50
53
54
+ @dataclasses .dataclass
55
+ class CallbackInfo :
56
+ """A class that shares information about a tensor that is to be saved as external data for callback functions.
57
+
58
+ Attributes:
59
+ total: The total number of tensors to save.
60
+ index: The index of the tensor being saved.
61
+ offset: The offset of the tensor in the external data file.
62
+ """
63
+
64
+ total : int
65
+ index : int
66
+ offset : int
67
+
68
+
51
69
def _all_tensors (
52
70
graph : _core .Graph | _core .GraphView , include_attributes : bool = False
53
71
) -> Iterator [_protocols .TensorProtocol ]:
@@ -157,19 +175,34 @@ def _write_external_data(
157
175
tensors : Sequence [_protocols .TensorProtocol ],
158
176
external_data_infos : Sequence [_ExternalDataInfo ],
159
177
file_path : str | os .PathLike ,
178
+ callback : Callable [[_protocols .TensorProtocol , CallbackInfo ], None ] | None = None ,
160
179
) -> None :
161
180
"""Write tensor data to an external file according to information stored in ExternalDataInfo objects.
162
181
163
182
Args:
164
183
tensors: Tensors to be written as external data.
165
184
external_data_infos: External data information stored for each tensor to be written as external data.
166
185
file_path: Location to which external data is to be stored.
186
+ callback: A callback function that is called for each tensor that is saved to external data
187
+ for debugging or logging purposes.
167
188
"""
168
- assert len (tensors ) == len (external_data_infos ), (
189
+ tensors_count = len (tensors )
190
+ assert tensors_count == len (external_data_infos ), (
169
191
"Number of tensors and external data infos should match"
170
192
)
171
193
with open (file_path , "wb" ) as data_file :
172
- for tensor , tensor_info in zip (tensors , external_data_infos , strict = True ):
194
+ for i , (tensor , tensor_info ) in enumerate (
195
+ zip (tensors , external_data_infos , strict = True )
196
+ ):
197
+ if callback is not None :
198
+ callback (
199
+ tensor ,
200
+ CallbackInfo (
201
+ total = tensors_count ,
202
+ index = i ,
203
+ offset = tensor_info .offset ,
204
+ ),
205
+ )
173
206
current_offset = tensor_info .offset
174
207
assert tensor is not None
175
208
raw_data = tensor .tobytes ()
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
228
261
tensors : Sequence [_protocols .TensorProtocol ],
229
262
base_dir : str | os .PathLike ,
230
263
relative_path : str | os .PathLike ,
264
+ callback : Callable [[_protocols .TensorProtocol , CallbackInfo ], None ] | None = None ,
231
265
) -> list [_core .ExternalTensor ]:
232
266
"""Convert a sequence of any TensorProtocol tensors to external tensors.
233
267
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
238
272
tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
239
273
base_dir: Path of base directory.
240
274
relative_path: Path to which external data is to be stored, relative to the ONNX file.
275
+ callback: A callback function that is called for each tensor that is saved to external data
276
+ for debugging or logging purposes.
241
277
242
278
Returns:
243
279
A list of external tensors derived from a list of input tensors. The order
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
285
321
external_info = _compute_external_data_info (tensor , current_offset )
286
322
external_data_infos .append (external_info )
287
323
current_offset = external_info .offset + external_info .length
288
- _write_external_data (sorted_tensors , external_data_infos , path )
324
+ _write_external_data (sorted_tensors , external_data_infos , path , callback = callback )
289
325
290
326
# Create external tensor objects
291
327
external_tensors : list [_core .ExternalTensor ] = [
@@ -336,6 +372,7 @@ def unload_from_model(
336
372
relative_path : str | os .PathLike ,
337
373
* ,
338
374
size_threshold_bytes : int = 0 ,
375
+ callback : Callable [[_protocols .TensorProtocol , CallbackInfo ], None ] | None = None ,
339
376
) -> _core .Model :
340
377
"""Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
341
378
@@ -356,6 +393,8 @@ def unload_from_model(
356
393
relative_path: Path to which external data is to be stored, relative to the ONNX file.
357
394
E.g. "model.data"
358
395
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
396
+ callback: A callback function that is called for each tensor that is saved to external data
397
+ for debugging or logging purposes.
359
398
360
399
Returns:
361
400
An ir.Model with all initializer data equal or above ``size_threshold_bytes``
@@ -384,6 +423,7 @@ def unload_from_model(
384
423
[v .const_value for v in initializers_to_become_external ], # type: ignore[misc]
385
424
base_dir = base_dir ,
386
425
relative_path = relative_path ,
426
+ callback = callback ,
387
427
)
388
428
389
429
# Replace the initializer values with external tensors and save the model
0 commit comments