Skip to content

Commit b0dcd6d

Browse files
committed
Update docstring
1 parent 2c00862 commit b0dcd6d

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

keras/src/models/model.py

+7-16
Original file line numberDiff line numberDiff line change
@@ -336,33 +336,24 @@ def save_weights(self, filepath, overwrite=True, max_shard_size=None):
336336
Example:
337337
338338
```python
339-
# Create a big model with about 272MB of weights.
340-
model = keras.Sequential()
341-
model.add(keras.Input(shape=(1024,)))
342-
for _ in range(5):
343-
model.add(keras.layers.Dense(4096))
339+
# Instantiate a EfficientNetV2L model with about 454MB of weights.
340+
model = keras.applications.EfficientNetV2L(weights=None)
344341
345342
# Save the weights in a single file.
346343
model.save_weights("model.weights.h5")
347344
348-
# Save the weights in sharded files. Use `max_shard_size=0.1` means each
349-
# sharded file will be at most ~100MB.
350-
model.save_weights("model.weights.json", max_shard_size=0.1)
345+
# Save the weights in sharded files. Use `max_shard_size=0.25` means
346+
# each sharded file will be at most ~250MB.
347+
model.save_weights("model.weights.json", max_shard_size=0.25)
351348
352349
# Load the weights in a new model with the same architecture.
353-
loaded_model = keras.Sequential()
354-
loaded_model.add(keras.Input(shape=(1024,)))
355-
for _ in range(5):
356-
loaded_model.add(keras.layers.Dense(4096))
350+
loaded_model = keras.applications.EfficientNetV2L(weights=None)
357351
loaded_model.load_weights("model.weights.h5")
358352
x = keras.random.uniform((1, 1024))
359353
assert np.allclose(model.predict(x), loaded_model.predict(x))
360354
361355
# Load the sharded weights in a new model with the same architecture.
362-
loaded_model = keras.Sequential()
363-
loaded_model.add(keras.Input(shape=(1024,)))
364-
for _ in range(5):
365-
loaded_model.add(keras.layers.Dense(4096))
356+
loaded_model = keras.applications.EfficientNetV2L(weights=None)
366357
loaded_model.load_weights("model.weights.json")
367358
x = keras.random.uniform((1, 1024))
368359
assert np.allclose(model.predict(x), loaded_model.predict(x))

0 commit comments

Comments
 (0)