Deepspeed + embeddings fix #254

Closed
wants to merge 29 commits into from

Conversation

anas-awadalla (Collaborator):
Added support for deepspeed and made training the embeddings cleaner using the Idefics method.
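
For context, this is roughly what the Idefics-style decoupled embedding does -- a minimal sketch, not the exact class used in this PR (class name and details are illustrative): the pretrained vocabulary rows stay frozen while the newly added special tokens get their own small trainable table.

```python
import torch
import torch.nn as nn

class DecoupledEmbeddingSketch(nn.Embedding):
    """Sketch of an Idefics-style decoupled embedding: ids below
    num_embeddings hit the (frozen) pretrained table, ids at or above it
    hit a small trainable table for the newly added special tokens."""

    def __init__(self, num_embeddings, num_additional_embeddings, embedding_dim, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.weight.requires_grad_(False)  # freeze the pretrained vocabulary
        self.additional_embedding = nn.Embedding(num_additional_embeddings, embedding_dim)

    def forward(self, input_ids):
        # Route new-token ids to the trainable table, everything else to
        # the frozen pretrained table.
        additional_mask = input_ids >= self.num_embeddings
        input_ids = input_ids.clone()
        additional_ids = input_ids[additional_mask] - self.num_embeddings
        input_ids[additional_mask] = 0  # placeholder id for the frozen lookup
        embeds = super().forward(input_ids)
        embeds[additional_mask] = self.additional_embedding(additional_ids)
        return embeds
```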

@@ -87,7 +87,6 @@ def set_input_embeddings(self, new_embeddings):
if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
lang_encoder.resize_token_embeddings(len(text_tokenizer))

model = Flamingo(

Collaborator (Author):
With the new embeddings method, do we need to convert the old checkpoints into this format? Or maybe put the new embeds behind a flag and note that this is the desired option for new models?

Collaborator:
Yeah, good question. My vote would be to preserve backward compatibility for now & place the new embeds behind a flag (default set to True)
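
A possible shape for that flag -- purely a sketch, with assumed names (`use_new_embeddings`, plus the `DecoupledEmbeddingSketch` class sketched in the PR description above); copying the pretrained weights into the new module is omitted:

```python
# Sketch only: flag name and wiring are assumptions, not this PR's API.
use_new_embeddings = True  # default True per the suggestion above

if use_new_embeddings:
    # New path: keep the pretrained matrix at its original size and train
    # only the rows for the added special tokens.
    old = lang_encoder.get_input_embeddings()
    num_new_tokens = len(text_tokenizer) - old.num_embeddings
    lang_encoder.set_input_embeddings(
        DecoupledEmbeddingSketch(old.num_embeddings, num_new_tokens, old.embedding_dim)
    )
else:
    # Old path: resize the full embedding matrix, which is what existing
    # checkpoints expect.
    lang_encoder.resize_token_embeddings(len(text_tokenizer))
```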


# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
model.perceiver.requires_grad_(True)

Collaborator:
For readability: keep this line if the following line is being kept?

Comment on lines 125 to 129
# create a get_output_embeddings() / set_output_embeddings() method if it doesn't exist
# this is needed for compatibility
if not hasattr(self, "get_output_embeddings"):
self.get_output_embeddings = lambda: self.lm_head
self.set_output_embeddings = lambda x: setattr(self, "lm_head", x)

Collaborator:
Could we move this to factory.py to match how we handle get_input_embeddings for MPT-1B (factory.py, line 73)? That would avoid hard-coding for MPT.
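
Something along these lines in factory.py, next to the existing MPT-1B get_input_embeddings patch -- a sketch only, with assumed condition and attribute names:

```python
# Sketch: patch output-embedding accessors in factory.py instead of the
# generic wrapper. Note that MPT ties its input/output embeddings, so the
# real fix may need to go through the tied wte weights rather than lm_head.
if not hasattr(lang_encoder, "get_output_embeddings"):
    lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
    lang_encoder.set_output_embeddings = lambda x: setattr(lang_encoder, "lm_head", x)
```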

@@ -277,3 +278,187 @@ def forward(
x = self.ff(x) * self.ff_gate.tanh() + x

return x

Collaborator:
This is great! Before we merge, I'll need to test with FSDP; hopefully this resolves some of the issues that previously required freezing the LM embeddings.

@@ -218,6 +233,8 @@ def main():

args = parser.parse_args()

args.local_rank = int(os.environ.get("LOCAL_RANK", -1)) # for deepspeed

Collaborator:
Hmm, just to understand what's happening here -- deepspeed populates this environment variable, which we load, but then our code overwrites this variable based on Slurm logic on line 272. Are we sure these match?
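
One way to make a mismatch loud rather than silent -- a sketch, not code from this PR (env-var handling assumed):

```python
import os

def resolve_local_rank() -> int:
    """Sketch: reconcile the LOCAL_RANK exported by the deepspeed launcher
    with the value Slurm provides, and fail loudly if they disagree."""
    env_rank = int(os.environ.get("LOCAL_RANK", -1))
    slurm_rank = int(os.environ.get("SLURM_LOCALID", -1))
    if env_rank != -1 and slurm_rank != -1 and env_rank != slurm_rank:
        raise RuntimeError(
            f"LOCAL_RANK={env_rank} disagrees with SLURM_LOCALID={slurm_rank}"
        )
    return env_rank if env_rank != -1 else slurm_rank
```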

Comment on lines 368 to 382
# Initialize gradient checkpointing
if args.gradient_checkpointing:
non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
offload_to_cpu=True,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
ddp_model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
and not isinstance(m, FSDP)
and not isinstance(m, CheckpointWrapper),
)

Collaborator:
Haven't run this to double check, but I think for non-deepspeed, the checkpointing logic needs to come before initializing the optimizer, or the optimizer may be referring to parameters w/o checkpoint wrapper classes, while the model refers to parameters w/ the wrapper. Could we check this?
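
For reference, the ordering being suggested, reusing the names from the snippet above (the optimizer construction is illustrative, e.g. `args.learning_rate` is an assumed flag):

```python
# 1) Wrap modules for activation checkpointing first ...
apply_activation_checkpointing(
    ddp_model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
    and not isinstance(m, FSDP)
    and not isinstance(m, CheckpointWrapper),
)

# 2) ... then build the optimizer, so it collects parameters from the same
# wrapped module tree that the model will actually train with.
optimizer = torch.optim.AdamW(
    (p for p in ddp_model.parameters() if p.requires_grad),
    lr=args.learning_rate,
)
```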

"stage3_param_persistence_threshold": 1e4,
"stage3_max_live_parameters": 3e7,
"stage3_prefetch_bucket_size": 3e7,
"stage3_gather_16bit_weights_on_model_save": True,

Collaborator:
Are all model weights saved in 16bit, regardless of the precision flag?

Collaborator (Author):
This saves additional 16-bit weights for stage 3, but it is unnecessary and will be slow! There is a script to 'reconstruct' the fp32 weights from the checkpoint, as described here.

Collaborator:
I see -- since we save ckpts fairly often, should we set this flag to False by default and provide a separate script to all_gather the weights into one file offline? (I guess wandb ckpt saving will need to be turned off.)

Collaborator (Author):
Deepspeed auto-creates the script in the checkpoint dir, which is nice :). Good point on wandb. For stage 3, yes, although we can also just save the sharded checkpoint.
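
To make that concrete -- a rough sketch, assuming the config dict is named `ds_config` and these keys live under `zero_optimization`:

```python
# Sketch: skip the extra 16-bit gather on every save ...
ds_config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = False

# ... and reconstruct a full fp32 state dict offline when needed. Deepspeed
# writes a zero_to_fp32.py helper into the checkpoint dir:
#   python <checkpoint_dir>/zero_to_fp32.py <checkpoint_dir> pytorch_model.bin
# or, programmatically:
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
state_dict = get_fp32_state_dict_from_zero_checkpoint("<checkpoint_dir>")
```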

if args.rank == 0:
if args.report_to_wandb and args.save_checkpoints_to_wandb:
wandb.save(f"{args.run_name}/epoch_{epoch}/mp_rank_00_model_states.pt")

Collaborator (Author):
TODO: handle saving stage3 shard state.
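
A possible direction for that TODO -- a sketch using deepspeed's engine-level checkpointing (`model_engine` and the wandb handling are assumptions):

```python
# Sketch: all ranks call save_checkpoint so each writes its ZeRO-3 shard
# (model_engine is the engine returned by deepspeed.initialize).
tag = f"epoch_{epoch}"
model_engine.save_checkpoint(args.run_name, tag=tag)

if args.rank == 0 and args.report_to_wandb and args.save_checkpoints_to_wandb:
    # Upload every shard for this tag (or convert to a single fp32 file
    # first with zero_to_fp32.py) instead of only mp_rank_00.
    wandb.save(f"{args.run_name}/{tag}/*", base_path=args.run_name)
```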

@anas-awadalla changed the base branch from main to mllm on September 16, 2023, 18:38