@@ -313,47 +313,72 @@ def save(self, filepath, overwrite=True, zipped=None, **kwargs):
         )

     @traceback_utils.filter_traceback
-    def save_weights(self, filepath, overwrite=True):
-        """Saves all layer weights to a `.weights.h5` file.
+    def save_weights(self, filepath, overwrite=True, max_shard_size=None):
+        """Saves all weights to a single file or sharded files.
+
+        By default, the weights will be saved in a single `.weights.h5` file.
+        If sharding is enabled (`max_shard_size` is not `None`), the weights
+        will be saved in multiple files, each with a size at most
+        `max_shard_size` (in GB). Additionally, a configuration file
+        `.weights.json` will contain the metadata for the sharded files.

         Args:
-            filepath: `str` or `pathlib.Path` object.
-                Path where to save the model. Must end in `.weights.h5`.
-            overwrite: Whether we should overwrite any existing model
-                at the target location, or instead ask the user
-                via an interactive prompt.
+            filepath: `str` or `pathlib.Path` object. Path where the weights
+                will be saved. When sharding, the filepath must end in
+                `.weights.json`. If `.weights.h5` is provided, it will be
+                overridden.
+            overwrite: Whether to overwrite any existing weights at the target
+                location or instead ask the user via an interactive prompt.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `None`.
         """
-        return saving_api.save_weights(self, filepath, overwrite=overwrite)
+        return saving_api.save_weights(
+            self, filepath, overwrite=overwrite, max_shard_size=max_shard_size
+        )

     @traceback_utils.filter_traceback
-    def load_weights(self, filepath, skip_mismatch=False, **kwargs):
-        """Load weights from a file saved via `save_weights()`.
+    def load_weights(
+        self, filepath, skip_mismatch=False, sharded=False, **kwargs
+    ):
+        """Load the weights from a single file or sharded files.

-        Weights are loaded based on the network's
-        topology. This means the architecture should be the same as when the
-        weights were saved. Note that layers that don't have weights are not
-        taken into account in the topological ordering, so adding or removing
-        layers is fine as long as they don't have weights.
+        Weights are loaded based on the network's topology. This means the
+        architecture should be the same as when the weights were saved. Note
+        that layers that don't have weights are not taken into account in the
+        topological ordering, so adding or removing layers is fine as long as
+        they don't have weights.

         **Partial weight loading**

         If you have modified your model, for instance by adding a new layer
-        (with weights) or by changing the shape of the weights of a layer,
-        you can choose to ignore errors and continue loading
-        by setting `skip_mismatch=True`. In this case any layer with
-        mismatching weights will be skipped. A warning will be displayed
-        for each skipped layer.
+        (with weights) or by changing the shape of the weights of a layer, you
+        can choose to ignore errors and continue loading by setting
+        `skip_mismatch=True`. In this case any layer with mismatching weights
+        will be skipped. A warning will be displayed for each skipped layer.
+
+        **Sharding**
+
+        When loading sharded weights, it is important to set `sharded=True`
+        and to specify a `filepath` that ends with `.weights.json`.

         Args:
-            filepath: String, path to the weights file to load.
-                It can either be a `.weights.h5` file
-                or a legacy `.h5` weights file.
+            filepath: `str` or `pathlib.Path` object. Path from which the
+                weights will be loaded. It can be a single `.weights.h5`
+                file, or, when loading sharded weights, a `.weights.json`
+                configuration file.
             skip_mismatch: Boolean, whether to skip loading of layers where
                 there is a mismatch in the number of weights, or a mismatch in
                 the shape of the weights.
+            sharded: Whether the saved file(s) are sharded. Defaults to
+                `False`.
         """
         saving_api.load_weights(
-            self, filepath, skip_mismatch=skip_mismatch, **kwargs
+            self,
+            filepath,
+            skip_mismatch=skip_mismatch,
+            sharded=sharded,
+            **kwargs,
         )

     def quantize(self, mode, **kwargs):
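
Below is a minimal usage sketch of the sharded-weights API introduced in this diff. The `keras.Sequential` model, the layer sizes, the file names, and the 2 GB shard size are illustrative assumptions; only the `max_shard_size` and `sharded` arguments come from the change itself.

# Minimal usage sketch, assuming the signatures shown in the diff above.
import keras

model = keras.Sequential(
    [
        keras.layers.Input(shape=(784,)),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dense(10),
    ]
)

# Default behavior: all weights go into a single `.weights.h5` file.
model.save_weights("model.weights.h5")
model.load_weights("model.weights.h5")

# Sharded saving: `max_shard_size` is in GB, and the filepath must end in
# `.weights.json`, which holds the metadata for the sharded files.
model.save_weights("model.weights.json", max_shard_size=2)

# Sharded loading: rebuild the same architecture, then pass `sharded=True`
# together with the `.weights.json` filepath.
fresh_model = keras.Sequential(
    [
        keras.layers.Input(shape=(784,)),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dense(10),
    ]
)
fresh_model.load_weights("model.weights.json", sharded=True)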