Deepspeed + embeddings fix #254
Conversation
@@ -87,7 +87,6 @@ def set_input_embeddings(self, new_embeddings):
    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
With the new embeddings method, do we need to convert the old checkpoints into this format? Or maybe put the new embeds behind a flag and say something about how this is the desired option for new models?
Yeah, good question. My vote would be to preserve backward compatibility for now & place the new embeds behind a flag (default set to True)
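A rough sketch of what that could look like in factory.py (the flag name `use_new_embeddings` and the helper below are assumptions for illustration, not part of this PR):

# Hypothetical sketch: gate the new embedding handling behind a flag so that old
# checkpoints, which expect the resized input embeddings, still load unchanged.
def configure_embeddings(lang_encoder, text_tokenizer, use_new_embeddings: bool = True):
    if use_new_embeddings:
        # new path from this PR: dedicated trainable embeddings for the added
        # <image> / <|endofchunk|> tokens, handled inside the Flamingo wrapper
        pass
    else:
        # legacy path: resize the LM's token embedding matrix in place
        lang_encoder.resize_token_embeddings(len(text_tokenizer))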
# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
model.perceiver.requires_grad_(True)
For readability: keep this line if the following line is being kept?
open_flamingo/src/flamingo_lm.py
Outdated
# create a get_output_embeddings() / set_output_embeddings() method if it doesn't exist
# this is needed for compatibility
if not hasattr(self, "get_output_embeddings"):
    self.get_output_embeddings = lambda: self.lm_head
    self.set_output_embeddings = lambda x: setattr(self, "lm_head", x)
Could we move this to factory.py to match how we handle get_input_embeddings for MPT-1B (factory.py:line 73)? That would avoid hard-coding for MPT.
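Roughly what that might look like in factory.py (a sketch only; it assumes the encoder exposes `lm_head`, mirroring the lambdas currently in flamingo_lm.py):

# Sketch: attach the accessors in factory.py, next to the existing MPT-1B
# get_input_embeddings workaround, instead of inside flamingo_lm.py.
if not hasattr(lang_encoder, "get_output_embeddings"):
    lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
    lang_encoder.set_output_embeddings = lambda x: setattr(lang_encoder, "lm_head", x)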
@@ -277,3 +278,187 @@ def forward(
        x = self.ff(x) * self.ff_gate.tanh() + x

        return x
This is great! Before we merge, I'll need to test with FSDP; I think this hopefully resolves some of the issues that required freezing the LM embeddings.
open_flamingo/train/train.py
Outdated
@@ -218,6 +233,8 @@ def main():

    args = parser.parse_args()

    args.local_rank = int(os.environ.get("LOCAL_RANK", -1))  # for deepspeed
Hmm, just to understand what's happening here -- deepspeed populates this environment variable, which we load, but then our code overwrites this variable based on Slurm logic on line 272. Are we sure these match?
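One way to check would be a sanity assertion before the Slurm logic runs (sketch only; `SLURM_LOCALID` is the usual Slurm variable, and the exact placement in train.py is assumed):

import os

# Sketch of a consistency check: the local rank deepspeed exports should agree
# with the one derived from Slurm before either value is used downstream.
deepspeed_local_rank = int(os.environ.get("LOCAL_RANK", -1))
slurm_local_rank = int(os.environ.get("SLURM_LOCALID", -1))
if deepspeed_local_rank != -1 and slurm_local_rank != -1:
    assert deepspeed_local_rank == slurm_local_rank, (
        f"LOCAL_RANK={deepspeed_local_rank} but SLURM_LOCALID={slurm_local_rank}"
    )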
open_flamingo/train/train.py
Outdated
# Initialize gradient checkpointing
if args.gradient_checkpointing:
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=True,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        ddp_model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
        and not isinstance(m, FSDP)
        and not isinstance(m, CheckpointWrapper),
    )
Haven't run this to double check, but I think for non-deepspeed, the checkpointing logic needs to come before initializing the optimizer, or the optimizer may be referring to parameters w/o checkpoint wrapper classes, while the model refers to parameters w/ the wrapper. Could we check this?
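Roughly the ordering being suggested (a sketch; `ddp_model` and `args` are the names used in train.py, and the optimizer line is only illustrative):

import functools
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Sketch: wrap modules for activation checkpointing first ...
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    offload_to_cpu=True,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
    ddp_model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False),
)

# ... and only then build the optimizer, so its param groups reference the
# parameters of the wrapped model.
optimizer = torch.optim.AdamW(
    (p for p in ddp_model.parameters() if p.requires_grad),
    lr=args.learning_rate,
)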
open_flamingo/train/train.py
Outdated
"stage3_param_persistence_threshold": 1e4, | ||
"stage3_max_live_parameters": 3e7, | ||
"stage3_prefetch_bucket_size": 3e7, | ||
"stage3_gather_16bit_weights_on_model_save": True, |
Are all model weights saved in 16bit, regardless of the precision flag?
This saves additional 16-bit weights for stage 3, but it's unnecessary and will be slow! There is a script to 'reconstruct' the fp32 weights from the checkpoint, as described here.
I see -- so since we save ckpts fairly often, should we set this flag to False by default, and then provide a separate script to all_gather the weights into one file offline? (I guess wandb ckpt saving will need to be turned off)
Deepspeed auto-creates the script in the checkpoint dir, which is nice :). Good point on wandb. For stage 3, yes, although we can also just save the sharded checkpoint.
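For reference, the offline reconstruction can also be done from Python via the same utility the auto-generated zero_to_fp32.py script uses (the paths below are assumptions based on the run_name/epoch layout in train_utils.py):

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Sketch: gather the sharded ZeRO checkpoint into a single fp32 state dict
# offline, instead of paying for stage3_gather_16bit_weights_on_model_save at
# every save.
checkpoint_dir = "run_name/epoch_0"  # assumed checkpoint directory
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
torch.save(state_dict, f"{checkpoint_dir}/consolidated_fp32.pt")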
open_flamingo/train/train_utils.py
Outdated
if args.rank == 0:
    if args.report_to_wandb and args.save_checkpoints_to_wandb:
        wandb.save(f"{args.run_name}/epoch_{epoch}/mp_rank_00_model_states.pt")
TODO: handle saving stage3 shard state.
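A possible shape for that TODO (sketch only; it assumes `ddp_model` is the deepspeed engine and reuses the args/wandb names from train_utils.py):

# Sketch: with stage 3, shards live on every rank, so every rank must call
# save_checkpoint(); only rank 0 then uploads the resulting files to wandb.
ddp_model.save_checkpoint(args.run_name, tag=f"epoch_{epoch}")
if args.rank == 0 and args.report_to_wandb and args.save_checkpoints_to_wandb:
    wandb.save(f"{args.run_name}/epoch_{epoch}/*")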
Added support for DeepSpeed and made training the embeddings cleaner using the Idefics method.