Deepspeed + embeddings fix #254

Closed. Wants to merge 29 commits.

Changes from 5 commits

Commits (29)
0bfba86  add other models, deepspeed attempt (i-gao, Aug 29, 2023)
ff0bdf5  fixed embedding freezing using Idefics among other things (anas-awadalla, Aug 31, 2023)
f7b45a4  Reverted specified files to the version in origin/main (anas-awadalla, Aug 31, 2023)
176bbc0  added deepspeed to reqs (anas-awadalla, Aug 31, 2023)
71ddca6  remove stage3 16 bit weights and local_rank arg (anas-awadalla, Aug 31, 2023)
76f4f8c  move get_embed func to factory.py and restore req grad call (anas-awadalla, Aug 31, 2023)
400fffc  fix bias check (Sep 1, 2023)
bb01717  another bias fix (Sep 1, 2023)
104975c  remove trust_remote_code as mpt is part of transformers now (anas-awadalla, Sep 1, 2023)
17babfe  add lm_head check (anas-awadalla, Sep 2, 2023)
10afa74  more changes (anas-awadalla, Sep 2, 2023)
a248c26  Update factory.py (anas-awadalla, Sep 2, 2023)
c453ca8  init on device to avoid cpu oom (i-gao, Sep 4, 2023)
de187aa  update eval script to use deepspeed (i-gao, Sep 4, 2023)
6f62054  restore pad_to_multiple_of kwarg in factory (i-gao, Sep 4, 2023)
e626afb  fixed embed not training issue (anas-awadalla, Sep 5, 2023)
b91da53  Merge branch 'deepspeed' of https://github.com/mlfoundations/open_fla… (anas-awadalla, Sep 5, 2023)
ecc74ad  embed training (Sep 5, 2023)
3230715  fix merge conflict (Sep 5, 2023)
fe39d1a  tie decoupled embeddings (anas-awadalla, Sep 5, 2023)
8bd7273  untie embeds for fsdp (anas-awadalla, Sep 5, 2023)
d10a998  move grad checkpointing before optimizer creation (anas-awadalla, Sep 5, 2023)
4f3ce24  default is not to untie (anas-awadalla, Sep 5, 2023)
2c5d864  fix embed init (anas-awadalla, Sep 5, 2023)
8ff8d5e  Update factory.py (anas-awadalla, Sep 5, 2023)
d75fecd  fix embed init and out embed concat (anas-awadalla, Sep 9, 2023)
3805d0f  Merge branch 'deepspeed' of https://github.com/mlfoundations/open_fla… (anas-awadalla, Sep 9, 2023)
27ce458  Merge branch 'deepspeed' into deepspeed_inference (anas-awadalla, Sep 15, 2023)
c2e95b4  Merge pull request #255 from mlfoundations/deepspeed_inference (anas-awadalla, Sep 16, 2023)
34 changes: 12 additions & 22 deletions open_flamingo/scripts/run_train.sh
@@ -1,42 +1,32 @@
 #!/bin/bash
 #SBATCH --nodes 1
-#SBATCH --ntasks-per-node=6
+#SBATCH --ntasks-per-node=8
 #SBATCH --gpus-per-task=1
-#SBATCH --account=efml
-#SBATCH --partition=gpu
-#SBATCH --time=48:00:00
-#SBATCH --job-name=flamingo

 export PYTHONFAULTHANDLER=1
 export CUDA_LAUNCH_BLOCKING=0
 export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
 export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
 export MASTER_PORT=15000
 export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
-export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface"

 export PYTHONPATH="$PYTHONPATH:open_flamingo"
-srun --cpu_bind=v --accel-bind=gn python
-
-
-deepspeed open_flamingo/open_flamingo/train/train.py \
-  --lm_path meta-llama/Llama-2-13b \
-  --tokenizer_path meta-llama/Llama-2-13b \
-  --cross_attn_every_n_layers 4 \
+srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
+  --lm_path anas-awadalla/mpt-1b-redpajama-200b \
+  --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
+  --cross_attn_every_n_layers 1 \
   --dataset_resampled \
-  --batch_size_mmc4 16 \
-  --batch_size_laion 32 \
-  --deepspeed \
+  --batch_size_mmc4 32 \
+  --batch_size_laion 64 \
   --train_num_samples_mmc4 125000\
   --train_num_samples_laion 250000 \
   --loss_multiplier_laion 0.2 \
   --workers=4 \
-  --run_name "deepspeed" \
+  --run_name OpenFlamingo-3B-vitl-mpt1b \
   --num_epochs 480 \
-  --warmup_steps 0 \
-  --mmc4_textsim_threshold 0.0 \
-  --laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \
-  --mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \
+  --warmup_steps 1875 \
+  --mmc4_textsim_threshold 0.24 \
+  --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
+  --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
-  --gradient_checkpointing \
   --report_to_wandb \
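
For context on the launcher change: the removed side of this diff started train.py through the deepspeed launcher and passed a --deepspeed flag, while the restored side goes back to the plain srun launch. The matching train.py changes are not part of this file's diff, so the following is only a rough, self-contained sketch of how a model is usually handed to DeepSpeed; the tiny model, config values, and loop are illustrative assumptions, not the PR's actual training code.

import deepspeed
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Stand-in module; in the PR, the Flamingo model built by factory.py would take this place."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 8)

    def forward(self, x):
        return self.net(x).pow(2).mean()  # returns a scalar "loss"

model = TinyModel()
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},  # illustrative; the PR's real DeepSpeed config is not shown here
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}

# Pass only trainable parameters, mirroring how factory.py freezes most of the model.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)

for _ in range(3):
    batch = torch.randn(4, 8, device=engine.device, dtype=torch.bfloat16)
    loss = engine(batch)
    engine.backward(loss)  # DeepSpeed handles loss scaling and ZeRO gradient partitioning
    engine.step()          # optimizer step + zero_grad, managed by the engine

A script like this is started with the deepspeed launcher (for example `deepspeed sketch.py`), which is what the removed `deepspeed open_flamingo/open_flamingo/train/train.py ...` line did for the real training entry point.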
22 changes: 10 additions & 12 deletions open_flamingo/src/factory.py
@@ -16,7 +16,6 @@ def create_model_and_transforms(
cross_attn_every_n_layers: int = 1,
use_local_files: bool = False,
decoder_layers_attr_name: str = None,
freeze_lm_embeddings: bool = False,
cache_dir: Optional[str] = None,
**flamingo_kwargs,
):
@@ -32,7 +31,6 @@
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
@@ -57,10 +55,12 @@
text_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
)
new_tokens = 2
if text_tokenizer.pad_token is None:
# Issue: GPT models don't have a pad token, which we use to
# modify labels for the loss.
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
new_tokens += 1

lang_encoder = AutoModelForCausalLM.from_pretrained(
lang_encoder_path,
@@ -80,16 +80,17 @@ def set_input_embeddings(self, new_embeddings):
self.transformer.wte = new_embeddings

extend_instance(lang_encoder, EmbeddingFnMixin)

if not hasattr(lang_encoder, "get_output_embeddings"):
lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
lang_encoder.set_output_embeddings = lambda x: setattr(lang_encoder, "lm_head", x)

# convert LM to FlamingoLM
extend_instance(lang_encoder, FlamingoLMMixin)

if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
lang_encoder.resize_token_embeddings(
len(text_tokenizer)
)

model = Flamingo(

Review comment (Collaborator Author): With the new embeddings method, do we need to convert the old checkpoints into this format? Or maybe have the new embeds behind a flag, and say that this is the desired option for new models?

Review comment (Collaborator): Yeah, good question. My vote would be to preserve backward compatibility for now & place the new embeds behind a flag (default set to True).

vision_encoder,
@@ -100,21 +101,18 @@ def set_input_embeddings(self, new_embeddings):
"width"
],
cross_attn_every_n_layers=cross_attn_every_n_layers,
new_tokens=new_tokens, # number of tokens embeddings to train
**flamingo_kwargs,
)

# Freeze all parameters
model.requires_grad_(False)
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
model.vision_encoder.requires_grad_(False)
model.lang_encoder.requires_grad_(False)

# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
# Unfreeze gated_cross_attn_layers and perceiver
model.perceiver.requires_grad_(True)

Review comment (Collaborator): For readability: keep this line if the following line is being kept?

model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)

# TODO: FIX this. Currently we are just training all embeddings unless freeze_lm_embeddings is on in which case we only train <image> and <eoc> embeddings
model.lang_encoder.get_input_embeddings().requires_grad_(True)
model.lang_encoder.get_output_embeddings().requires_grad_(True)

print(
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
)
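The net effect of the factory.py changes: count how many special tokens were added to the tokenizer, hand that count to Flamingo as new_tokens so the decoupled embedding layers know how many extra rows to train, and freeze everything except the perceiver, the gated cross-attention layers, and the input/output embeddings. A small self-contained sketch of just the token-counting step (the gpt2 tokenizer is an arbitrary example, not what this PR trains with):

from transformers import AutoTokenizer

text_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only

# Same two Flamingo special tokens as in factory.py above.
text_tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
)
new_tokens = 2
if text_tokenizer.pad_token is None:
    # GPT-style tokenizers have no pad token, so a third new token gets added.
    text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    new_tokens += 1

print(new_tokens)            # 3 for gpt2; 2 for tokenizers that already define a pad token
print(len(text_tokenizer))   # original vocab size plus the new special tokens

new_tokens then flows through Flamingo.__init__ into init_flamingo, which is what the flamingo.py and flamingo_lm.py diffs below wire up.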
4 changes: 4 additions & 0 deletions open_flamingo/src/flamingo.py
@@ -24,6 +24,7 @@ def __init__(
vis_dim: int,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
new_tokens: int = 2,
):
"""
Args:
@@ -34,6 +35,8 @@
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
new_tokens (int, optional): Number of new tokens added to the tokenizer. Defaults to 2.
"""
super().__init__()
self.eoc_token_id = eoc_token_id
@@ -53,6 +56,7 @@ def __init__(
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
new_tokens=new_tokens,
)
self._use_gradient_checkpointing = gradient_checkpointing
self.perceiver._use_gradient_checkpointing = gradient_checkpointing
38 changes: 37 additions & 1 deletion open_flamingo/src/flamingo_lm.py
@@ -1,5 +1,9 @@
import torch.nn as nn
from .helpers import GatedCrossAttentionBlock
from .helpers import (
GatedCrossAttentionBlock,
FlamingoDecoupledEmbedding,
FlamingoDecoupledLinear,
)
from .utils import getattr_recursive, setattr_recursive


@@ -87,6 +91,7 @@ def init_flamingo(
vis_hidden_size,
cross_attn_every_n_layers,
gradient_checkpointing,
new_tokens,
):
"""
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
@@ -104,6 +109,37 @@
)
self.init_flamingo_layers(gradient_checkpointing)
self.media_token_id = media_token_id

# modify the embedding layer to support decoupling
input_embed_weights = self.get_input_embeddings().weight
self.set_input_embeddings(
FlamingoDecoupledEmbedding(
num_embeddings=input_embed_weights.shape[0],
num_additional_embeddings=new_tokens,
embedding_dim=input_embed_weights.shape[1],
partially_freeze=True,
)
)
self.get_input_embeddings().weight = input_embed_weights

out_embeds = FlamingoDecoupledLinear(
in_features=input_embed_weights.shape[1],
out_features=input_embed_weights.shape[0],
bias=self.get_output_embeddings().bias is not None,
out_additional_features=new_tokens,
partially_freeze=True,
)

if getattr(self.config, "tie_word_embeddings", True):
out_embeds.weight = input_embed_weights
else:
out_embeds.weight = self.get_output_embeddings().weight

if self.get_output_embeddings().bias is not None:
out_embeds.bias = self.get_output_embeddings().bias

self.set_output_embeddings(out_embeds)

self.initialized_flamingo = True
self._use_cached_vision_x = False

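The core of the embedding fix is FlamingoDecoupledEmbedding / FlamingoDecoupledLinear, which (per the commit messages) are adapted from the Idefics-style decoupled embeddings: the pretrained rows stay frozen while only the rows for the newly added tokens receive gradients. Below is a minimal sketch of that idea; the class name and routing here are illustrative and this is not the actual helpers.py implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoupledEmbeddingSketch(nn.Module):
    """A frozen pretrained table plus a small trainable table for newly added tokens."""

    def __init__(self, num_embeddings, num_additional_embeddings, embedding_dim, partially_freeze=True):
        super().__init__()
        self.num_embeddings = num_embeddings
        # Pretrained rows: frozen when partially_freeze=True.
        self.weight = nn.Parameter(
            torch.randn(num_embeddings, embedding_dim), requires_grad=not partially_freeze
        )
        # Rows for the newly added tokens (<image>, <|endofchunk|>, <PAD>): always trainable.
        self.additional_embedding = nn.Embedding(num_additional_embeddings, embedding_dim)

    def forward(self, input_ids):
        input_ids = input_ids.clone()
        is_new = input_ids >= self.num_embeddings
        # Look up new tokens in the small trainable table.
        new_embeds = self.additional_embedding(input_ids[is_new] - self.num_embeddings)
        # Look up everything else in the frozen pretrained table.
        input_ids[is_new] = 0  # placeholder index; these rows are overwritten below
        out = F.embedding(input_ids, self.weight)
        out[is_new] = new_embeds
        return out

# 6 frozen pretrained rows, 2 trainable rows for the added special tokens.
emb = DecoupledEmbeddingSketch(num_embeddings=6, num_additional_embeddings=2, embedding_dim=4)
tokens = torch.tensor([[1, 5, 6, 7]])  # ids 6 and 7 are the newly added tokens
print(emb(tokens).shape)               # torch.Size([1, 4, 4])

FlamingoDecoupledLinear applies the same idea to the output projection, adding out_additional_features extra logits for the new tokens on top of the (optionally frozen) pretrained lm_head weight, as the diff above shows.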