Skip to content

Commit 4503ecd

Browse files
committed
Update docstring
1 parent 2c00862 commit 4503ecd

File tree

1 file changed

+9
-18
lines changed

1 file changed

+9
-18
lines changed

keras/src/models/model.py

+9-18
Original file line numberDiff line numberDiff line change
@@ -336,35 +336,26 @@ def save_weights(self, filepath, overwrite=True, max_shard_size=None):
336336
Example:
337337
338338
```python
339-
# Create a big model with about 272MB of weights.
340-
model = keras.Sequential()
341-
model.add(keras.Input(shape=(1024,)))
342-
for _ in range(5):
343-
model.add(keras.layers.Dense(4096))
339+
# Instantiate a EfficientNetV2L model with about 454MB of weights.
340+
model = keras.applications.EfficientNetV2L(weights=None)
344341
345342
# Save the weights in a single file.
346343
model.save_weights("model.weights.h5")
347344
348-
# Save the weights in sharded files. Use `max_shard_size=0.1` means each
349-
# sharded file will be at most ~100MB.
350-
model.save_weights("model.weights.json", max_shard_size=0.1)
345+
# Save the weights in sharded files. Use `max_shard_size=0.25` means
346+
# each sharded file will be at most ~250MB.
347+
model.save_weights("model.weights.json", max_shard_size=0.25)
351348
352349
# Load the weights in a new model with the same architecture.
353-
loaded_model = keras.Sequential()
354-
loaded_model.add(keras.Input(shape=(1024,)))
355-
for _ in range(5):
356-
loaded_model.add(keras.layers.Dense(4096))
350+
loaded_model = keras.applications.EfficientNetV2L(weights=None)
357351
loaded_model.load_weights("model.weights.h5")
358-
x = keras.random.uniform((1, 1024))
352+
x = keras.random.uniform((1, 480, 480, 3))
359353
assert np.allclose(model.predict(x), loaded_model.predict(x))
360354
361355
# Load the sharded weights in a new model with the same architecture.
362-
loaded_model = keras.Sequential()
363-
loaded_model.add(keras.Input(shape=(1024,)))
364-
for _ in range(5):
365-
loaded_model.add(keras.layers.Dense(4096))
356+
loaded_model = keras.applications.EfficientNetV2L(weights=None)
366357
loaded_model.load_weights("model.weights.json")
367-
x = keras.random.uniform((1, 1024))
358+
x = keras.random.uniform((1, 480, 480, 3))
368359
assert np.allclose(model.predict(x), loaded_model.predict(x))
369360
```
370361

0 commit comments

Comments
 (0)