Introduce weights sharding #21022
Conversation
Force-pushed from `6315817` to `e7aa098`.
Codecov Report

@@ Coverage Diff @@
## master #21022 +/- ##
==========================================
+ Coverage 82.45% 82.60% +0.14%
==========================================
Files 562 562
Lines 53720 53834 +114
Branches 8335 8360 +25
==========================================
+ Hits 44297 44467 +170
+ Misses 7381 7291 -90
- Partials 2042 2076 +34
This looks great to me! @fchollet probably worth looking at this one!
My main question was on the loading side. Do we need to ask people to pass `sharded=True`? Can we remove the arg, or leave it `None` and infer by default?
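One option (a rough sketch, not the PR's API): infer it from the file name, since sharded checkpoints go through a `.weights.json` index while single-file checkpoints use `.weights.h5`:

def _infer_sharded(filepath):
    # Hypothetical helper: the extension alone distinguishes a sharded
    # save (a `.weights.json` index file) from a single-file save
    # (`.weights.h5`), so no explicit `sharded` argument is needed.
    return str(filepath).endswith(".weights.json")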
if "dtype" in value.attrs and value.attrs["dtype"] == "bfloat16": | ||
value = np.array(value, dtype=ml_dtypes.bfloat16) | ||
return value | ||
def __delitem__(self, key): |
Is this new? Why do we need `del`?
I just added it because it might be helpful. I'm considering making the store object more dict-like, but it's subject to change.
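For reference, here is a minimal sketch (hypothetical, not the PR's code) of what a fully dict-like store could look like, wrapping an open `h5py` group in a `MutableMapping`:

from collections.abc import MutableMapping

class H5Store(MutableMapping):
    """Hypothetical dict-like wrapper around an open `h5py` group."""

    def __init__(self, group):
        self._group = group  # an open `h5py.Group`

    def __getitem__(self, key):
        return self._group[key]

    def __setitem__(self, key, value):
        self._group[key] = value

    def __delitem__(self, key):
        del self._group[key]

    def __iter__(self):
        return iter(self._group)

    def __len__(self):
        return len(self._group)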
        location or instead ask the user via an interactive prompt.
    max_shard_size: `int` or `float`. Maximum size in GB for each
        sharded file. If `None`, no sharding will be done. Defaults to
        `None`.
"""
Might be worth adding some code examples, for those who don't want to read and just want to see how to shard :)
Sounds good!
I have added an example:
import numpy as np
import keras

# Instantiate an EfficientNetV2L model with about 454MB of weights.
model = keras.applications.EfficientNetV2L(weights=None)

# Save the weights in a single file.
model.save_weights("model.weights.h5")

# Save the weights in sharded files. Using `max_shard_size=0.25` means
# each sharded file will be at most ~250MB.
model.save_weights("model.weights.json", max_shard_size=0.25)

# Load the weights into a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.h5")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

# Load the sharded weights into a new model with the same architecture.
loaded_model = keras.applications.EfficientNetV2L(weights=None)
loaded_model.load_weights("model.weights.json")
x = keras.random.uniform((1, 480, 480, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
Instead of:

# each sharded file will be at most ~250MB
model.save_weights("model.weights.json", max_shard_size=0.25)

how about this, for better readability:

# each sharded file will be at most ~250MB
model.save_weights("model.weights.json", max_shard_size='250MB')
I agree that using a string is more readable, but I think @fchollet wanted it to be an int.
Maybe we can support both int and string?
#19286 (comment)
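If both were supported, a small normalizer could convert either form to bytes. A rough sketch, with `parse_max_shard_size` as a hypothetical name:

import re

def parse_max_shard_size(value):
    # Numbers keep the current semantics: size in GB.
    if isinstance(value, (int, float)):
        return int(value * 1024**3)
    # Strings like "250MB" or "2GB" are parsed explicitly.
    match = re.fullmatch(
        r"(\d+(?:\.\d+)?)\s*(KB|MB|GB)", value.strip().upper()
    )
    if match is None:
        raise ValueError(f"Cannot parse max_shard_size: {value!r}")
    number, unit = float(match.group(1)), match.group(2)
    return int(number * {"KB": 1024, "MB": 1024**2, "GB": 1024**3}[unit])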
Force-pushed from `b0dcd6d` to `4503ecd`.
Continuing the work based on #19286.

This PR introduces `max_shard_size` in `Model.save_weights`.

Behind the scenes, this PR refactors `H5IOStore` by combining it with `H5Entry`. This change allows for more fine-grained control when storing weights. Specifically, it enables the creation of a new shard file once the current shard file reaches its capacity due to incoming weights.

Compatibility has been verified, but please let me know if anything was overlooked.

The test for LoRA weights saving/loading in KerasHub has been included.

Ping @mattdangerw and @divyashreepathihalli, who requested this feature.
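To make the shard rollover described above concrete, here is a minimal sketch of the idea (illustrative only, not the PR's actual implementation): greedily fill the current shard and start a new one when the next weight would overflow it.

def assign_to_shards(weights, max_shard_size_bytes):
    # Hypothetical helper: `weights` maps names to NumPy arrays; each
    # shard holds as many consecutive weights as fit under the budget.
    shards, current, current_size = [], {}, 0
    for name, array in weights.items():
        size = array.nbytes
        if current and current_size + size > max_shard_size_bytes:
            shards.append(current)  # current shard is full; roll over
            current, current_size = {}, 0
        current[name] = array
        current_size += size
    if current:
        shards.append(current)
    return shards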
A simple demo script:
The outputs:
The format of `.weights.json` (similar to HF's format):
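Assuming the index mirrors HF's sharded-checkpoint convention (a metadata block plus a weight-to-shard map), it might look roughly like the following (all keys and values here are illustrative assumptions, not the PR's actual schema):

# Illustrative sketch only; the actual schema may differ.
index = {
    "metadata": {"total_size": 454000000},  # total weight bytes
    "weight_map": {
        # weight path -> shard file that stores it
        "layers/conv1/vars/0": "model_00000.weights.h5",
        "layers/dense/vars/0": "model_00001.weights.h5",
    },
}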