Commit 6315817

Introduce weights sharding
1 parent df2fc37 commit 6315817

File tree: 5 files changed, +501 -101 lines changed

keras/src/models/model.py (+49 -24)
```diff
@@ -313,47 +313,72 @@ def save(self, filepath, overwrite=True, zipped=None, **kwargs):
         )
 
     @traceback_utils.filter_traceback
-    def save_weights(self, filepath, overwrite=True):
-        """Saves all layer weights to a `.weights.h5` file.
+    def save_weights(self, filepath, overwrite=True, max_shard_size=None):
+        """Saves all weights to a single file or sharded files.
+
+        By default, the weights will be saved in a single `.weights.h5` file.
+        If sharding is enabled (`max_shard_size` is not `None`), the weights
+        will be saved in multiple files, each with a size at most
+        `max_shard_size` (in GB). Additionally, a configuration file
+        `.weights.json` will contain the metadata for the sharded files.
 
         Args:
-            filepath: `str` or `pathlib.Path` object.
-                Path where to save the model. Must end in `.weights.h5`.
-            overwrite: Whether we should overwrite any existing model
-                at the target location, or instead ask the user
-                via an interactive prompt.
+            filepath: `str` or `pathlib.Path` object. Path where the weights
+                will be saved. When sharding, the filepath must end in
+                `.weights.json`. If `.weights.h5` is provided, it will be
+                overridden.
+            overwrite: Whether to overwrite any existing weights at the target
+                location or instead ask the user via an interactive prompt.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `None`.
         """
-        return saving_api.save_weights(self, filepath, overwrite=overwrite)
+        return saving_api.save_weights(
+            self, filepath, overwrite=overwrite, max_shard_size=max_shard_size
+        )
 
     @traceback_utils.filter_traceback
-    def load_weights(self, filepath, skip_mismatch=False, **kwargs):
-        """Load weights from a file saved via `save_weights()`.
+    def load_weights(
+        self, filepath, skip_mismatch=False, sharded=False, **kwargs
+    ):
+        """Load the weights from a single file or sharded files.
 
-        Weights are loaded based on the network's
-        topology. This means the architecture should be the same as when the
-        weights were saved. Note that layers that don't have weights are not
-        taken into account in the topological ordering, so adding or removing
-        layers is fine as long as they don't have weights.
+        Weights are loaded based on the network's topology. This means the
+        architecture should be the same as when the weights were saved. Note
+        that layers that don't have weights are not taken into account in the
+        topological ordering, so adding or removing layers is fine as long as
+        they don't have weights.
 
         **Partial weight loading**
 
         If you have modified your model, for instance by adding a new layer
-        (with weights) or by changing the shape of the weights of a layer,
-        you can choose to ignore errors and continue loading
-        by setting `skip_mismatch=True`. In this case any layer with
-        mismatching weights will be skipped. A warning will be displayed
-        for each skipped layer.
+        (with weights) or by changing the shape of the weights of a layer, you
+        can choose to ignore errors and continue loading by setting
+        `skip_mismatch=True`. In this case any layer with mismatching weights
+        will be skipped. A warning will be displayed for each skipped layer.
+
+        **Sharding**
+
+        When loading sharded weights, it is important to set `sharded=True`
+        and to specify a `filepath` that ends with `.weights.json`.
 
         Args:
-            filepath: String, path to the weights file to load.
-                It can either be a `.weights.h5` file
-                or a legacy `.h5` weights file.
+            filepath: `str` or `pathlib.Path` object. Path to the weights
+                file to load. When loading sharded weights, the filepath
+                must end in `.weights.json`.
             skip_mismatch: Boolean, whether to skip loading of layers where
                 there is a mismatch in the number of weights, or a mismatch in
                 the shape of the weights.
+            sharded: Whether the saved file(s) are sharded. Defaults to
+                `False`.
         """
         saving_api.load_weights(
-            self, filepath, skip_mismatch=skip_mismatch, **kwargs
+            self,
+            filepath,
+            skip_mismatch=skip_mismatch,
+            sharded=sharded,
+            **kwargs,
         )
 
     def quantize(self, mode, **kwargs):
```
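Taken together, the new `Model` methods can be used as below. A minimal sketch; the toy model and the shard size are illustrative, not part of the commit:

```python
import keras

# Toy model; any built Keras model works the same way.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])

# Shard the weights into files of at most 1 GB each; the path must end
# in `.weights.json`, which will hold the metadata for the shards.
model.save_weights("model.weights.json", max_shard_size=1)

# Reloading requires the same architecture, `sharded=True`, and the
# `.weights.json` path.
clone = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
clone.load_weights("model.weights.json", sharded=True)
```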

keras/src/saving/saving_api.py (+24 -10)
```diff
@@ -219,32 +219,45 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
 
 
 @keras_export("keras.saving.save_weights")
-def save_weights(model, filepath, overwrite=True, **kwargs):
-    if not str(filepath).endswith(".weights.h5"):
+def save_weights(
+    model, filepath, overwrite=True, max_shard_size=None, **kwargs
+):
+    filepath_str = str(filepath)
+    if max_shard_size is None and not filepath_str.endswith(".weights.h5"):
         raise ValueError(
             "The filename must end in `.weights.h5`. "
-            f"Received: filepath={filepath}"
+            f"Received: filepath={filepath_str}"
+        )
+    elif max_shard_size is not None and not filepath_str.endswith(
+        ("weights.h5", "weights.json")
+    ):
+        raise ValueError(
+            "The filename must end in `.weights.json` when `max_shard_size` is "
+            f"specified. Received: filepath={filepath_str}"
         )
     try:
         exists = os.path.exists(filepath)
     except TypeError:
         exists = False
     if exists and not overwrite:
-        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
+        proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str)
         if not proceed:
             return
-    saving_lib.save_weights_only(model, filepath, **kwargs)
+    saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs)
 
 
 @keras_export("keras.saving.load_weights")
-def load_weights(model, filepath, skip_mismatch=False, **kwargs):
-    if str(filepath).endswith(".keras"):
+def load_weights(model, filepath, skip_mismatch=False, sharded=False, **kwargs):
+    filepath_str = str(filepath)
+    if filepath_str.endswith(".keras"):
         if kwargs:
             raise ValueError(f"Invalid keyword arguments: {kwargs}")
         saving_lib.load_weights_only(
-            model, filepath, skip_mismatch=skip_mismatch
+            model, filepath, skip_mismatch=skip_mismatch, sharded=sharded
         )
-    elif str(filepath).endswith(".weights.h5"):
+    elif filepath_str.endswith(".weights.h5") or filepath_str.endswith(
+        ".weights.json"
+    ):
         objects_to_skip = kwargs.pop("objects_to_skip", None)
         if kwargs:
             raise ValueError(f"Invalid keyword arguments: {kwargs}")
@@ -253,8 +266,9 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
             filepath,
             skip_mismatch=skip_mismatch,
             objects_to_skip=objects_to_skip,
+            sharded=sharded,
         )
-    elif str(filepath).endswith(".h5") or str(filepath).endswith(".hdf5"):
+    elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
         by_name = kwargs.pop("by_name", False)
         if kwargs:
             raise ValueError(f"Invalid keyword arguments: {kwargs}")
```
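The same behavior is reachable through the functional `keras.saving` entry points exported above. A hedged sketch of the new validation rules; the toy model is an assumption:

```python
import keras

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])

# No sharding: the filename must end in `.weights.h5`.
keras.saving.save_weights(model, "model.weights.h5")

# Sharding: `.weights.json` is expected (note that the check above also
# lets a bare `weights.h5` suffix through).
keras.saving.save_weights(model, "model.weights.json", max_shard_size=0.5)

# Any other extension raises:
#   keras.saving.save_weights(model, "model.bin")  # ValueError

# Loading dispatches on the extension; a `.weights.json` path goes through
# `saving_lib.load_weights_only(..., sharded=sharded)`.
keras.saving.load_weights(model, "model.weights.json", sharded=True)
```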

keras/src/saving/saving_api_test.py (+15 -0)
```diff
@@ -1,4 +1,5 @@
 import os
+import pathlib
 import unittest.mock as mock
 
 import numpy as np
@@ -237,6 +238,20 @@ def test_load_weights_invalid_extension(self):
         with self.assertRaisesRegex(ValueError, "File format not supported"):
             model.load_weights("invalid_extension.pkl")
 
+    def test_load_sharded_weights(self):
+        src_model = self.get_model()
+        temp_filepath = pathlib.Path(
+            os.path.join(self.get_temp_dir(), "test_weights.weights.json")
+        )
+        src_model.save_weights(temp_filepath, max_shard_size=1)
+        self.assertLen(os.listdir(temp_filepath.parent), 2)
+        src_weights = src_model.get_weights()
+        dest_model = self.get_model()
+        dest_model.load_weights(temp_filepath, sharded=True)
+        dest_weights = dest_model.get_weights()
+        for orig, loaded in zip(src_weights, dest_weights):
+            self.assertAllClose(orig, loaded)
+
 
 class SaveModelTestsWarning(test_case.TestCase):
     def get_model(self):
```
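Note the `assertLen(..., 2)` assertion: with `max_shard_size=1` (1 GB) this small test model fits in a single shard, so the temp directory is expected to contain exactly two files, presumably the `.weights.json` metadata file plus one shard file.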
