docs: fixed the init_module and deepspeed #20175

Open
wants to merge 8 commits into base: master
13 changes: 12 additions & 1 deletion in docs/source-fabric/advanced/model_init.rst
@@ -69,7 +69,7 @@ When training distributed models with :doc:`FSDP/TP <model_parallel/index>` or DeepSpeed

.. code-block:: python

-    # Recommended for FSDP, TP and DeepSpeed
+    # Recommended for FSDP and TP
     with fabric.init_module(empty_init=True):
         model = GPT3()  # parameters are placed on the meta-device

@@ -79,6 +79,17 @@ When training distributed models with :doc:`FSDP/TP <model_parallel/index>` or DeepSpeed
     optimizer = torch.optim.Adam(model.parameters())
     optimizer = fabric.setup_optimizers(optimizer)

+With DeepSpeed Stage 3, the :meth:`~lightning.fabric.fabric.Fabric.init_module` context manager is required so that the model is sharded correctly instead of being placed on the GPU in its entirety. DeepSpeed also requires the model and optimizer to be set up jointly.
+
+.. code-block:: python
+
+    # Required with DeepSpeed Stage 3
+    with fabric.init_module(empty_init=True):
+        model = GPT3()
+
+    optimizer = torch.optim.Adam(model.parameters())
+    model, optimizer = fabric.setup(model, optimizer)
+
.. note::
    Empty-init is experimental and the behavior may change in the future.
    For distributed models, all user-defined modules that manage parameters must implement a ``reset_parameters()`` method (all PyTorch built-in modules already do).
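As a minimal sketch, a user-defined module that satisfies this requirement could look like the following (the ``TwoLayerBlock`` class and its initialization scheme are hypothetical examples, not part of the library):

.. code-block:: python

    import torch
    import torch.nn as nn


    class TwoLayerBlock(nn.Module):
        # Hypothetical module that manages a parameter of its own
        def __init__(self, dim: int):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(dim, dim))
            self.linear = nn.Linear(dim, dim)  # built-in modules already define reset_parameters()

        def reset_parameters(self) -> None:
            # Re-initializes the parameters this module owns, so they can be
            # materialized after being created on the meta-device
            nn.init.kaiming_uniform_(self.weight)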