
Introduce weights sharding #21022

Open
wants to merge 4 commits into master
Conversation

james77777778 (Contributor) commented Mar 13, 2025

Continuing the work based on #19286

This PR introduces max_shard_size in Model.save_weights.

Behind the scenes, this PR refactors H5IOStore by combining it with H5Entry. This change allows for more fine-grained control when storing weights. Specifically, it enables the creation of a new shard file once the current shard file reaches its capacity due to incoming weights.
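
The rollover can be pictured roughly as follows. This is a hedged sketch of the idea only, not the PR's actual H5IOStore code; all names here are hypothetical.

class ShardedStore:
    def __init__(self, basename, max_shard_size_bytes):
        self.basename = basename
        self.max_bytes = max_shard_size_bytes
        self.shard_index = -1
        self._open_new_shard()

    def _open_new_shard(self):
        # Each shard gets a zero-padded index, e.g. model_00000.weights.h5.
        self.shard_index += 1
        self.current_path = f"{self.basename}_{self.shard_index:05d}.weights.h5"
        self.current_bytes = 0

    def add(self, name, array):
        # Start a new shard file once the incoming weight would push the
        # current shard past its capacity.
        if self.current_bytes + array.nbytes > self.max_bytes:
            self._open_new_shard()
        self.current_bytes += array.nbytes
        # ... write `array` under `name` into `self.current_path` ...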

Compatibility has been verified, but please let me know if anything was overlooked.
The test for LoRA weights saving/loading in KerasHub has been included.

Pinging @mattdangerw and @divyashreepathihalli, who requested this feature.

A simple demo script:

import os

import numpy as np

from keras import applications

# EfficientNetB0 is about 20.3MB.
model = applications.EfficientNetB0(weights=None, input_shape=(224, 224, 3))
ref_input = np.random.random((1, 224, 224, 3)).astype("float32")
ref_output = model.predict(ref_input)

# `max_shard_size` is in GB. 0.015 means about 15MB per shard.
model.save_weights("model.weights.json", max_shard_size=0.015)
files = os.listdir(".")
assert "model.weights.json" in files
assert "model_00000.weights.h5" in files
assert "model_00001.weights.h5" in files
print("Sharded weights saved successfully!")

# Load the sharded weights with the new instance.
model = applications.EfficientNetB0(weights=None, input_shape=(224, 224, 3))
model.load_weights("model.weights.json", sharded=True)
np.testing.assert_allclose(model.predict(ref_input), ref_output, atol=1e-6)
print("Passed!")

The outputs:

1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Sharded weights saved successfully!
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 996ms/step
Passed!

The format of .weights.json (similar to Hugging Face's sharded-checkpoint index format):

{
    "metadata": {
        "total_size": 476111392.0
    },
    "weight_map": {
        "/vars": "model_00000.weighs.h5",
        "/layers/input_layer/vars": "model_00000.weighs.h5",
        "/layers/rescaling/vars": "model_00000.weighs.h5",
        "/layers/conv2d/vars": "model_00000.weighs.h5",
        "/layers/batch_normalization/vars": "model_00000.weighs.h5",
...
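
Given that layout, a loader only has to group variable paths by shard and open each shard file once. A rough sketch, assuming the manifest looks exactly like the JSON above (this is not the PR's loader code):

import json

import h5py

with open("model.weights.json") as f:
    manifest = json.load(f)

# Group variable paths by the shard file that holds them.
shards = {}
for path, filename in manifest["weight_map"].items():
    shards.setdefault(filename, []).append(path)

# Open each shard once and visit its groups.
for filename, paths in shards.items():
    with h5py.File(filename, "r") as f:
        for path in paths:
            group = f[path]  # e.g. f["/layers/conv2d/vars"]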

codecov-commenter commented Mar 13, 2025

Codecov Report

Attention: Patch coverage is 74.28571% with 45 lines in your changes missing coverage. Please review.

Project coverage is 82.60%. Comparing base (decd6ba) to head (4503ecd).

Files with missing lines         Patch %   Lines
keras/src/saving/saving_lib.py   74.69%    25 Missing and 16 partials ⚠️
keras/src/saving/saving_api.py   63.63%    2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21022      +/-   ##
==========================================
+ Coverage   82.45%   82.60%   +0.14%     
==========================================
  Files         562      562              
  Lines       53720    53834     +114     
  Branches     8335     8360      +25     
==========================================
+ Hits        44297    44467     +170     
+ Misses       7381     7291      -90     
- Partials     2042     2076      +34     
Flag Coverage Δ
keras 82.41% <74.28%> (+0.13%) ⬆️
keras-jax 63.82% <74.28%> (+0.16%) ⬆️
keras-numpy 58.84% <73.71%> (+0.22%) ⬆️
keras-openvino 32.69% <11.42%> (-0.05%) ⬇️
keras-tensorflow 64.25% <74.28%> (+0.16%) ⬆️
keras-torch 63.85% <74.28%> (+0.17%) ⬆️


mattdangerw (Member) left a comment

This looks great to me! @fchollet probably worth looking at this one!

My main question is on the loading side: do we need to ask people to pass `sharded=True`? Can we remove the arg, or leave it `None` and infer by default?
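
For what it's worth, the inference could look something like this, assuming a sharded manifest always ends in .weights.json (a hypothetical sketch, not code from this PR):

def load_weights(self, filepath, sharded=None, **kwargs):
    if sharded is None:
        # Manifests use `.weights.json`; single-file weights use `.weights.h5`.
        sharded = str(filepath).endswith(".weights.json")
    ...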

if "dtype" in value.attrs and value.attrs["dtype"] == "bfloat16":
value = np.array(value, dtype=ml_dtypes.bfloat16)
return value
def __delitem__(self, key):
Member

Is this new? Why do we need del?

Contributor Author

I just added it because it might be helpful. I'm considering making the store object more dict-like, but it's subject to change.

            location or instead ask the user via an interactive prompt.
        max_shard_size: `int` or `float`. Maximum size in GB for each
            sharded file. If `None`, no sharding will be done. Defaults to
            `None`.
    """
Member

Might be worth adding some code examples, for those who don't want to read and just want to see how to shard :)

james77777778 (Contributor Author) commented Mar 14, 2025

Sounds good!

I have added an example:

import numpy as np

import keras

# Instantiate an EfficientNetV2L model with about 454MB of weights.
model = keras.applications.EfficientNetV2L(weights=None)

# Save the weights in a single file.
model.save_weights("model.weights.h5")

# Save the weights in sharded files. Using `max_shard_size=0.25` means
# each sharded file will be at most ~250MB.
model.save_weights("model.weights.json", max_shard_size=0.25)

# Load the weights into a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.h5")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

# Load the sharded weights into a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.json")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))


Instead of

# each sharded file will be at most ~250MB
model.save_weights("model.weights.json", max_shard_size=0.25)

How about this, for better readability:

# each sharded file will be at most ~250MB
model.save_weights("model.weights.json", max_shard_size='250MB')

Contributor Author

I agree that a string is more readable, but I think @fchollet wanted it to be an int. Maybe we can support both int and string?
See #19286 (comment).
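
For reference, supporting both could be a small parsing step. A hypothetical sketch, not code from this PR; `parse_max_shard_size` and its unit table are made-up names:

import re

_UNITS = {"KB": 10**3, "MB": 10**6, "GB": 10**9}

def parse_max_shard_size(value):
    # Accept an int/float in GB (e.g. 0.25) or a string such as "250MB",
    # and normalize either form to a byte count.
    if isinstance(value, (int, float)):
        return int(value * 10**9)  # GB -> bytes
    match = re.fullmatch(r"\s*([0-9.]+)\s*(KB|MB|GB)\s*", value.upper())
    if match is None:
        raise ValueError(f"Cannot parse `max_shard_size`: {value!r}")
    number, unit = match.groups()
    return int(float(number) * _UNITS[unit])

assert parse_max_shard_size(0.25) == parse_max_shard_size("250MB")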
