From 0bfba86fa6add481ff974520f04c9c199bddf2fe Mon Sep 17 00:00:00 2001 From: i-gao Date: Mon, 28 Aug 2023 19:17:52 -0700 Subject: [PATCH 01/24] add other models, deepspeed attempt --- open_flamingo/eval/eval_model.py | 110 +++++++++--- open_flamingo/eval/evaluate.py | 10 +- open_flamingo/eval/models/blip.py | 139 ++++++++------- open_flamingo/eval/models/idefics.py | 192 +++++++++++++++++++++ open_flamingo/eval/models/open_flamingo.py | 165 ++++++------------ 5 files changed, 419 insertions(+), 197 deletions(-) create mode 100644 open_flamingo/eval/models/idefics.py diff --git a/open_flamingo/eval/eval_model.py b/open_flamingo/eval/eval_model.py index 672d1fd0..c011deea 100644 --- a/open_flamingo/eval/eval_model.py +++ b/open_flamingo/eval/eval_model.py @@ -1,14 +1,16 @@ import abc -import argparse from typing import List from torch.nn.parallel import DistributedDataParallel as DDP from PIL import Image +from utils import get_autocast, get_cast_dtype +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast class BaseEvalModel(abc.ABC): """Base class encapsulating functionality needed to evaluate a model.""" - def __init__(self, args: List[str]): + def __init__(self, model_args: List[str]): """Initialize model. Args: @@ -17,23 +19,99 @@ def __init__(self, args: List[str]): is non-empty. """ - def init_distributed(self): - """Wrap model as DDP.""" - self.model = DDP(self.model, device_ids=[self.device]) + def __init__(self, model_args): + assert "lm_path" in model_args, "All models require the lm_path argument" + self.device = ( + model_args["device"] + if ("device" in model_args and model_args["device"] >= 0) + else "cpu" + ) + precision = model_args.get("precision", "fp32") + self.lm_name = model_args["lm_path"].split("/")[-1] + self.autocast = get_autocast(precision) + self.cast_dtype = get_cast_dtype(precision) + + def _check_init(self): + """Finish model initialization.""" + assert hasattr(self, "model"), "Model has not been initialized" + self.model.eval() + self.model.to(self.device, dtype=self.cast_dtype) + assert hasattr(self, "tokenizer"), "Tokenizer has not been initialized" + self.tokenizer.padding_side = "left" + + def init_distributed(self, world_size=None, use_deepspeed=False): + """Wrap model as DDP or deepspeed.""" + if use_deepspeed: + import deepspeed + + self.ds_engine = deepspeed.init_inference( + self.model, + mp_size=world_size, + dtype=self.cast_dtype, + checkpoint=None, + replace_with_kernel_inject=True, + ) + self.model = self.ds_engine.module + else: + self.model = DDP(self.model, device_ids=[self.device]) def set_device(self, device): """Set device for model.""" self.device = device self.model = self.model.to(device) + def __call__( + self, + lang_x: torch.Tensor, + vision_x: torch.Tensor, + attention_mask: torch.Tensor, + past_key_values: torch.Tensor = None, + use_cache: bool = False, + ): + """ + Calls the forward function of the model. + Special logic to handle the case if past_key_values is not None: + then lang_x is assumed to contain the tokens to be generated + *excluding* the tokens already in past_key_values. + We then repeatedly call forward, updating the past_key_values. + """ + + def prepare_text( + self, + batch: List[List[str]], + padding="longest", + truncation=True, + max_length=2000, + add_special_tokens=True, + ): + """ + Prepare text for model. + + Args: + batch: list of text strings + padding: whether to pad the text + truncation: whether to truncate the text + max_length: maximum length of the text + + Returns: + input_ids: tensor of shape (B, T) + attention_mask: tensor of shape (B, T) + """ + + def prepare_images(self, batch: List[List[Image.Image]]): + """ + Prepare images for model. + Args: + batch: list of lists of PIL images + Returns: + tensor of shape (B, T, *, C, H, W) + """ + def get_outputs( self, batch_text: List[str], batch_images: List[List[Image.Image]], - min_generation_length: int, - max_generation_length: int, - num_beams: int, - length_penalty: float, + **decode_kwargs, ) -> List[str]: """Get outputs for a batch of images and text. @@ -51,20 +129,6 @@ def get_outputs( List of decoded output strings. """ - def vqa_prompt(self, question, answer=None) -> str: - """Get the prompt to use for VQA evaluation. If the answer is not provided, it should be left blank to be generated by the model. - - Returns: - The prompt to use for VQA. - """ - - def caption_prompt(self, caption=None) -> str: - """Get the prompt to use for caption evaluation. If the caption is not provided, it should be left blank to be generated by the model. - - Returns: - The prompt to use for captioning. - """ - def get_rank_classifications( self, batch_text: List[str], diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index a2f17eed..872668f3 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -388,6 +388,12 @@ action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", ) +parser.add_argument( + "--deepspeed", + default=False, + action="store_true", + help="Whether to use deepspeed for distributed inference.", +) def main(): @@ -403,7 +409,9 @@ def main(): args.local_rank, args.rank, args.world_size = world_info_from_env() device_id = init_distributed_device(args) eval_model.set_device(device_id) - eval_model.init_distributed() + eval_model.init_distributed( + world_size=args.world_size, use_deepspeed=args.deepspeed + ) if args.model != "open_flamingo" and args.shots != [0]: raise ValueError("Only 0 shot eval is supported for non-open_flamingo models") diff --git a/open_flamingo/eval/models/blip.py b/open_flamingo/eval/models/blip.py index 40ea124d..82cbdd14 100644 --- a/open_flamingo/eval/models/blip.py +++ b/open_flamingo/eval/models/blip.py @@ -4,97 +4,95 @@ import torch from transformers import Blip2Processor, Blip2ForConditionalGeneration -from open_flamingo.eval.eval_model import BaseEvalModel -from open_flamingo.eval.utils import unwrap_model +from models.eval_model import BaseEvalModel +from utils import unwrap_model +from transformers.modeling_outputs import CausalLMOutputWithPast class EvalModel(BaseEvalModel): - """BLIP-2 model evaluation. + """BLIP-2 model evaluation.""" - Attributes: - model (nn.Module): Underlying Torch model. - tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. - device: Index of GPU to use, or the string "cpu" - """ - - def __init__(self, model_args): + def __init__(self, **model_args): assert ( "processor_path" in model_args and "lm_path" in model_args ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified" - + super().__init__(model_args) self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) self.model = Blip2ForConditionalGeneration.from_pretrained( model_args["lm_path"] ) - self.model.eval() - self.processor.tokenizer.padding_side = "left" - self.lm_name = model_args["lm_path"].split("/")[-1] - - def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor: - """Preprocess images and stack them. - - Args: - batch: A list of lists of images. + self.tokenizer = self.processor.tokenizer + self._check_init() - Returns: - A Tensor of shape - (batch_size, channels, height, width). - """ + def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: batch_images = None assert all( len(example) == 1 for example in batch ), "BLIP-2 only supports one image per example" - for example in batch: - assert len(example) == 1, "BLIP-2 only supports one image per example" - batch_images = torch.cat( - [ - batch_images, - self.processor.image_processor(example, return_tensors="pt")[ - "pixel_values" - ], - ] - if batch_images is not None - else [ - self.processor.image_processor(example, return_tensors="pt")[ - "pixel_values" + if batch_images is None: + batch_images = self.processor.image_processor( + example, return_tensors="pt" + )["pixel_values"] + else: + batch_images = torch.cat( + [ + batch_images, + self.processor.image_processor(example, return_tensors="pt")[ + "pixel_values" + ], ] - ], - dim=0, + ) + if batch_images is not None: + batch_images = batch_images.to( + self.device, dtype=self.cast_dtype, non_blocking=True ) return batch_images - def get_outputs( + def prepare_text( self, - batch_text: List[str], - batch_images: List[List[Image.Image]], - min_generation_length: int, - max_generation_length: int, - num_beams: int, - length_penalty: float, - ) -> List[str]: - encodings = self.processor.tokenizer( - batch_text, - padding="longest", - truncation=True, + batch: List[List[str]], + padding="longest", + truncation=True, + max_length=2000, + add_special_tokens=True, + ): + encodings = self.tokenizer( + batch, + padding=padding, + truncation=truncation, return_tensors="pt", - max_length=2000, + max_length=max_length, + add_special_tokens=add_special_tokens, ) input_ids = encodings["input_ids"] attention_mask = encodings["attention_mask"] + input_ids = input_ids.to(self.device, non_blocking=True) + attention_mask = attention_mask.to(self.device, non_blocking=True) + return input_ids, attention_mask + + def get_outputs( + self, + batch_text: List[str], + batch_images: List[List[Image.Image]], + **decode_kwargs, + ) -> List[str]: + batch_images = self.prepare_images(batch_images) # (B, C, H, W) + input_ids, attention_mask = self.prepare_text(batch_text) with torch.inference_mode(): - outputs = unwrap_model(self.model).generate( - self._prepare_images(batch_images).to(self.device), - input_ids.to(self.device), - attention_mask=attention_mask.to(self.device), - max_new_tokens=max_generation_length, - min_new_tokens=min_generation_length, - num_beams=num_beams, - length_penalty=length_penalty, - ) + with self.autocast(): + outputs = unwrap_model(self.model).generate( + batch_images, + input_ids, + attention_mask=attention_mask, + **decode_kwargs, + ) - return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True) + # Extract only the new gnerated tokens + outputs = outputs[:, len(input_ids[0]) :] + + return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) def get_vqa_prompt(self, question, answer=None) -> str: return ( @@ -104,6 +102,23 @@ def get_vqa_prompt(self, question, answer=None) -> str: def get_caption_prompt(self, caption=None) -> str: return f"A photo of {caption if caption is not None else ''}" + def __call__( + self, + lang_x: torch.Tensor, + vision_x: torch.Tensor, + attention_mask: torch.Tensor, + ): + with self.autocast(): + outputs = self.model( + pixel_values=vision_x, + input_ids=lang_x, + attention_mask=attention_mask, + ) + + # remove vision tokens + outputs.logits = outputs.logits[:, -lang_x.size(1) :, :] + return outputs + def get_rank_classifications( self, batch_text: List[str], diff --git a/open_flamingo/eval/models/idefics.py b/open_flamingo/eval/models/idefics.py new file mode 100644 index 00000000..8707eb91 --- /dev/null +++ b/open_flamingo/eval/models/idefics.py @@ -0,0 +1,192 @@ +from typing import List, Dict + +from PIL import Image +import torch +from einops import repeat + +from models.eval_model import BaseEvalModel +from transformers import IdeficsForVisionText2Text, AutoProcessor +from transformers.models.idefics.processing_idefics import ( + image_attention_mask_for_packed_input_ids, + incremental_to_binary_attention_mask, +) +from transformers.modeling_outputs import CausalLMOutputWithPast +from utils import unwrap_model + + +class EvalModel(BaseEvalModel): + """IDEFICS model evaluation.""" + + def __init__(self, **model_args): + assert ( + "lm_path" in model_args and "processor_path" in model_args + ), "IDEFICS requires lm_path and lm_tokenizer_path" + super().__init__(model_args) + self.model = IdeficsForVisionText2Text.from_pretrained(model_args["lm_path"]) + self.processor = AutoProcessor.from_pretrained(model_args["processor_path"]) + self.tokenizer = self.processor.tokenizer + self._check_init() + + def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: + batch_images = self.processor(batch)["pixel_values"] + if batch_images is not None: + batch_images = batch_images.to( + self.device, dtype=self.cast_dtype, non_blocking=True + ) + return batch_images + + def prepare_text( + self, + batch: List[List[str]], + padding="longest", + truncation=True, + max_length=2000, + add_special_tokens=True, + ): + # check to see if there any without wrapping it + for i, text in enumerate(batch): + if "" in text and "" not in text: + print( + "Warning: missing in text; inserting automatically." + ) + batch[i] = text.replace( + "", + "", + ) + + encodings = self.tokenizer( + batch, + padding=padding, + truncation=truncation, + return_tensors="pt", + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] + input_ids = input_ids.to(self.device, non_blocking=True) + attention_mask = attention_mask.to(self.device, non_blocking=True) + return input_ids, attention_mask + + def _compute_image_attention_mask(self, batch_tokens: torch.Tensor) -> torch.Tensor: + """ + From: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/idefics/processing_idefics.py + """ + max_num_images = torch.max(torch.sum(batch_tokens == 32001, dim=-1)).item() + at_least_one_image = max_num_images > 0 + if at_least_one_image: + image_attention_mask, _ = image_attention_mask_for_packed_input_ids( + batch_tokens, self.tokenizer + ) + image_attention_mask = incremental_to_binary_attention_mask( + image_attention_mask, num_classes=max_num_images + ) + else: + # in full language mode we set the image mask to all-0s + image_attention_mask = torch.zeros( + batch_tokens.shape[0], batch_tokens.shape[1], 1, dtype=torch.bool + ) + return image_attention_mask + + def get_rank_classifications( + self, + batch_text: List[str], + batch_images: List[List[Image.Image]], + all_class_names: List[str], + use_cache: bool, + normalize_length: bool, + ): + """ + Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. + """ + raise NotImplementedError + + def get_outputs( + self, + batch_text: List[str], + batch_images: List[List[Image.Image]], + **decode_kwargs, + ) -> List[str]: + batch_images = self.prepare_images(batch_images) + input_ids, attention_mask = self.prepare_text(batch_text) + image_attention_mask = self._compute_image_attention_mask(input_ids) + + with torch.inference_mode(): + with self.autocast(): + outputs = unwrap_model(self.model).generate( + pixel_values=batch_images, + image_attention_mask=image_attention_mask, + input_ids=input_ids, + attention_mask=attention_mask, + **decode_kwargs, + ) + + # Extract only the new gnerated tokens + outputs = outputs[:, len(input_ids[0]) :] + return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def __call__( + self, + lang_x: torch.Tensor, + vision_x: torch.Tensor, + attention_mask: torch.Tensor, + past_key_values: torch.Tensor = None, + use_cache: bool = False, + ): + image_attention_mask = self._compute_image_attention_mask(lang_x) + + # standard forward pass + if past_key_values is None: + with self.autocast(): + outputs = self.model( + pixel_values=vision_x, + image_attention_mask=image_attention_mask, + input_ids=lang_x, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + ) + return outputs + + # loop to handle updating past_key_values + logits = [] + for token_idx in range(lang_x.shape[1]): + _lang_x = lang_x[:, token_idx].reshape((-1, 1)) + if attention_mask is not None: + _attention_mask = attention_mask[:, token_idx].reshape((-1, 1)) + else: + _attention_mask = None + + with self.autocast(): + outputs = self.model( + pixel_values=vision_x, + image_attention_mask=image_attention_mask, + input_ids=_lang_x, + attention_mask=_attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + past_key_values = outputs.past_key_values + logits.append(outputs.logits) + + logits = torch.cat(logits, dim=1) + return CausalLMOutputWithPast( + logits=logits, + past_key_values=past_key_values, + ) + + def get_vqa_prompt(self, question, answer=None) -> str: + # TODO: handle prefix prompts + return f"Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" + + def get_caption_prompt(self, caption=None) -> str: + # TODO: handle prefix prompts + return f"Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" + + def get_imagenet_prompt(self, label=None) -> str: + # TODO: handle prefix prompts + return f"Output:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}" + + def get_hateful_memes_prompt(self, text, label=None) -> str: + # TODO: handle prefix prompts + return f"is an image with: '{text}' written on it. Is it hateful? Answer: {label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}" diff --git a/open_flamingo/eval/models/open_flamingo.py b/open_flamingo/eval/models/open_flamingo.py index 63ac0f14..0fe848ea 100644 --- a/open_flamingo/eval/models/open_flamingo.py +++ b/open_flamingo/eval/models/open_flamingo.py @@ -11,13 +11,7 @@ class EvalModel(BaseEvalModel): - """OpenFlamingo model evaluation. - - Attributes: - model (nn.Module): Underlying Torch model. - tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model. - device: Index of GPU to use, or the string "CPU" - """ + """OpenFlamingo model evaluation.""" def __init__(self, model_args): assert ( @@ -27,15 +21,8 @@ def __init__(self, model_args): and "lm_tokenizer_path" in model_args and "cross_attn_every_n_layers" in model_args and "vision_encoder_pretrained" in model_args - and "precision" in model_args - ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained, and precision arguments to be specified" - - self.device = ( - model_args["device"] - if ("device" in model_args and model_args["device"] >= 0) - else "cpu" - ) - + ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained arguments to be specified" + super().__init__(model_args) ( self.model, self.image_processor, @@ -47,31 +34,14 @@ def __init__(self, model_args): model_args["lm_tokenizer_path"], cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), ) - checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) + checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu") if "model_state_dict" in checkpoint: checkpoint = checkpoint["model_state_dict"] checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} self.model.load_state_dict(checkpoint, strict=False) - self.model.to(self.device) - self.model.eval() - self.tokenizer.padding_side = "left" + self._check_init() - self.lm_name = model_args["lm_path"].split("/")[-1] - - # autocast - self.autocast = get_autocast(model_args["precision"]) - self.cast_dtype = get_cast_dtype(model_args["precision"]) - - def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: - """ - Convert images to tensors, reshape them, and stack them. - Args: - batch: A list of lists of images. - Returns: - preprocessed images (tensors) or None - shape (B, T_img, F, C, H, W) - None if no images in batch - """ + def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: images_per_example = max(len(x) for x in batch) batch_images = None for iexample, example in enumerate(batch): @@ -89,67 +59,47 @@ def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: ) return batch_images - def _prepare_text( + def prepare_text( self, batch: List[List[str]], padding="longest", truncation=True, max_length=2000, + add_special_tokens=True, ): - """ - Tokenize the text and stack them. - Args: - batch: A list of lists of strings. - Returns: - input_ids (tensor) - shape (B, T_txt) - attention_mask (tensor) - shape (B, T_txt) - """ encodings = self.tokenizer( batch, padding=padding, truncation=truncation, return_tensors="pt", max_length=max_length, + add_special_tokens=add_special_tokens, ) input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] - input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True) - attention_mask = attention_mask.to( - self.device, dtype=self.cast_dtype, non_blocking=True - ) + input_ids = input_ids.to(self.device, non_blocking=True) + attention_mask = attention_mask.to(self.device, non_blocking=True) return input_ids, attention_mask.bool() def get_outputs( self, batch_text: List[str], batch_images: List[List[Image.Image]], - min_generation_length: int, - max_generation_length: int, - num_beams: int, - length_penalty: float, + **decode_kwargs, ) -> List[str]: - """ - Get generation outputs. - """ - batch_images = self._prepare_images(batch_images) - input_ids, attention_mask = self._prepare_text(batch_text) + batch_images = self.prepare_images(batch_images) # (B, T, 1, C, H, W) + input_ids, attention_mask = self.prepare_text(batch_text) with torch.inference_mode(): with self.autocast(): outputs = unwrap_model(self.model).generate( - batch_images, - input_ids, - attention_mask, - min_new_tokens=min_generation_length, - max_new_tokens=max_generation_length, - num_beams=num_beams, - length_penalty=length_penalty, + vision_x=batch_images, + lang_x=input_ids, + attention_mask=attention_mask, + **decode_kwargs, ) # Extract only the new gnerated tokens outputs = outputs[:, len(input_ids[0]) :] - return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) def get_rank_classifications( @@ -163,8 +113,8 @@ def get_rank_classifications( """ Returns a (B, |all_class_names|) tensor containing the logprobs for each class name. """ - batch_images = self._prepare_images(batch_images) - ctx_input_ids, ctx_attention_mask = self._prepare_text(batch_text) + batch_images = self.prepare_images(batch_images) + ctx_input_ids, ctx_attention_mask = self.prepare_text(batch_text) # Cache the context if use_cache: @@ -173,13 +123,14 @@ def get_rank_classifications( input_ids=ctx_input_ids, vision_x=batch_images, ) - precomputed = self.__call__( - vision_x=None, - lang_x=ctx_input_ids, - attention_mask=ctx_attention_mask, - clear_conditioned_layers=False, - use_cache=True, - ) + with torch.inference_mode(): + precomputed = self.__call__( + vision_x=None, + lang_x=ctx_input_ids, + attention_mask=ctx_attention_mask, + clear_conditioned_layers=False, + use_cache=True, + ) precomputed_logits = precomputed.logits precomputed_pkvs = precomputed.past_key_values else: @@ -218,13 +169,14 @@ def get_rank_classifications( _vision_x = None # Call forward to get the logits - outputs = self.__call__( - vision_x=_vision_x, - lang_x=_lang_x, - attention_mask=_attention_mask, - clear_conditioned_layers=(not use_cache), - past_key_values=precomputed_pkvs, - ) + with torch.inference_mode(): + outputs = self.__call__( + vision_x=_vision_x, + lang_x=_lang_x, + attention_mask=_attention_mask, + clear_conditioned_layers=(not use_cache), + past_key_values=precomputed_pkvs, + ) # Get the logits of the classname # logits shape is either (B, num_tokens_in_classname, vocab_len) with use_cache @@ -262,25 +214,17 @@ def __call__( clear_conditioned_layers: bool = False, use_cache: bool = False, ): - """ - Calls the forward function of the model. - Special logic to handle the case if past_key_values is not None: - then lang_x is assumed to contain the tokens to be generated - *excluding* the tokens already in past_key_values. - We then repeatedly call forward, updating the past_key_values. - """ # standard forward pass if past_key_values is None: - with torch.inference_mode(): - with self.autocast(): - outputs = self.model( - vision_x=vision_x, - lang_x=lang_x, - attention_mask=attention_mask, - clear_conditioned_layers=clear_conditioned_layers, - past_key_values=past_key_values, - use_cache=use_cache, - ) + with self.autocast(): + outputs = self.model( + vision_x=vision_x, + lang_x=lang_x, + attention_mask=attention_mask, + clear_conditioned_layers=clear_conditioned_layers, + past_key_values=past_key_values, + use_cache=use_cache, + ) return outputs # loop to handle updating past_key_values @@ -292,16 +236,15 @@ def __call__( else: _attention_mask = None - with torch.inference_mode(): - with self.autocast(): - outputs = self.model( - vision_x=vision_x, - lang_x=_lang_x, - attention_mask=_attention_mask, - clear_conditioned_layers=False, - past_key_values=past_key_values, - use_cache=True, - ) + with self.autocast(): + outputs = self.model( + vision_x=vision_x, + lang_x=_lang_x, + attention_mask=_attention_mask, + clear_conditioned_layers=False, + past_key_values=past_key_values, + use_cache=True, + ) past_key_values = outputs.past_key_values logits.append(outputs.logits) From ff0bdf5842f25465538e42a396afeebf1a633715 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Thu, 31 Aug 2023 14:35:00 -0700 Subject: [PATCH 02/24] fixed embedding freezing using Idefics among other things --- open_flamingo/eval/models/open_flamingo.py | 17 +- open_flamingo/src/factory.py | 19 +-- open_flamingo/src/flamingo.py | 4 + open_flamingo/src/flamingo_lm.py | 44 ++++- open_flamingo/src/helpers.py | 185 +++++++++++++++++++++ open_flamingo/train/train.py | 20 +-- open_flamingo/train/train_utils.py | 25 --- 7 files changed, 253 insertions(+), 61 deletions(-) diff --git a/open_flamingo/eval/models/open_flamingo.py b/open_flamingo/eval/models/open_flamingo.py index 63ac0f14..ce3a6f0b 100644 --- a/open_flamingo/eval/models/open_flamingo.py +++ b/open_flamingo/eval/models/open_flamingo.py @@ -8,7 +8,7 @@ from open_flamingo.src.factory import create_model_and_transforms from open_flamingo.eval.utils import unwrap_model, get_autocast, get_cast_dtype from transformers.modeling_outputs import CausalLMOutputWithPast - +from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint class EvalModel(BaseEvalModel): """OpenFlamingo model evaluation. @@ -47,11 +47,12 @@ def __init__(self, model_args): model_args["lm_tokenizer_path"], cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), ) - checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) - if "model_state_dict" in checkpoint: - checkpoint = checkpoint["model_state_dict"] - checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} - self.model.load_state_dict(checkpoint, strict=False) + self.model = load_state_dict_from_zero_checkpoint(self.model, model_args["checkpoint_path"]) + # checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) + # if "model_state_dict" in checkpoint: + # checkpoint = checkpoint["model_state_dict"] + # checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} + # self.model.load_state_dict(checkpoint, strict=False) self.model.to(self.device) self.model.eval() self.tokenizer.padding_side = "left" @@ -114,9 +115,9 @@ def _prepare_text( max_length=max_length, ) input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] - input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True) + input_ids = input_ids.to(self.device, non_blocking=True) attention_mask = attention_mask.to( - self.device, dtype=self.cast_dtype, non_blocking=True + self.device, non_blocking=True ) return input_ids, attention_mask.bool() diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index d8a86c6c..13917f60 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -16,7 +16,6 @@ def create_model_and_transforms( cross_attn_every_n_layers: int = 1, use_local_files: bool = False, decoder_layers_attr_name: str = None, - freeze_lm_embeddings: bool = False, cache_dir: Optional[str] = None, **flamingo_kwargs, ): @@ -32,7 +31,6 @@ def create_model_and_transforms( cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. use_local_files (bool, optional): whether to use local files. Defaults to False. decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. - freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver. cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. Returns: Flamingo: Flamingo model from pretrained vision and language encoders @@ -57,10 +55,12 @@ def create_model_and_transforms( text_tokenizer.add_special_tokens( {"additional_special_tokens": ["<|endofchunk|>", ""]} ) + new_tokens = 2 if text_tokenizer.pad_token is None: # Issue: GPT models don't have a pad token, which we use to # modify labels for the loss. text_tokenizer.add_special_tokens({"pad_token": ""}) + new_tokens += 1 lang_encoder = AutoModelForCausalLM.from_pretrained( lang_encoder_path, @@ -87,9 +87,6 @@ def set_input_embeddings(self, new_embeddings): if decoder_layers_attr_name is None: decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) - lang_encoder.resize_token_embeddings( - len(text_tokenizer) - ) model = Flamingo( vision_encoder, @@ -100,21 +97,17 @@ def set_input_embeddings(self, new_embeddings): "width" ], cross_attn_every_n_layers=cross_attn_every_n_layers, + new_tokens=new_tokens, # number of tokens embeddings to train **flamingo_kwargs, ) # Freeze all parameters - model.requires_grad_(False) - assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 + model.vision_encoder.requires_grad_(False) + model.lang_encoder.requires_grad_(False) - # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings - model.perceiver.requires_grad_(True) + # Unfreeze gated_cross_attn_layers model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) - # TODO: FIX this. Currently we are just training all embeddings unless freeze_lm_embeddings is on in which case we only train and embeddings - model.lang_encoder.get_input_embeddings().requires_grad_(True) - model.lang_encoder.get_output_embeddings().requires_grad_(True) - print( f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" ) diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py index 9a67cfed..ee02b65f 100644 --- a/open_flamingo/src/flamingo.py +++ b/open_flamingo/src/flamingo.py @@ -24,6 +24,7 @@ def __init__( vis_dim: int, cross_attn_every_n_layers: int = 1, gradient_checkpointing: bool = False, + new_tokens: int = 2, ): """ Args: @@ -34,6 +35,8 @@ def __init__( vis_dim (int): Dimension of the visual features. Visual features are projected to match this shape along the last dimension. cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. + new_tokens (int, optional): Number of new tokens added to the tokenizer. Defaults to 2. """ super().__init__() self.eoc_token_id = eoc_token_id @@ -53,6 +56,7 @@ def __init__( vis_hidden_size=self.vis_dim, cross_attn_every_n_layers=cross_attn_every_n_layers, gradient_checkpointing=gradient_checkpointing, + new_tokens=new_tokens, ) self._use_gradient_checkpointing = gradient_checkpointing self.perceiver._use_gradient_checkpointing = gradient_checkpointing diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index c4933e9d..643da8a3 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -1,5 +1,9 @@ import torch.nn as nn -from .helpers import GatedCrossAttentionBlock +from .helpers import ( + GatedCrossAttentionBlock, + FlamingoDecoupledEmbedding, + FlamingoDecoupledLinear, +) from .utils import getattr_recursive, setattr_recursive @@ -87,6 +91,7 @@ def init_flamingo( vis_hidden_size, cross_attn_every_n_layers, gradient_checkpointing, + new_tokens, ): """ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. @@ -104,6 +109,43 @@ def init_flamingo( ) self.init_flamingo_layers(gradient_checkpointing) self.media_token_id = media_token_id + + # modify the embedding layer to support decoupling + input_embed_weights = self.get_input_embeddings().weight + self.set_input_embeddings( + FlamingoDecoupledEmbedding( + num_embeddings=input_embed_weights.shape[0], + num_additional_embeddings=new_tokens, + embedding_dim=input_embed_weights.shape[1], + partially_freeze=True, + ) + ) + self.get_input_embeddings().weight = input_embed_weights + + # create a get_output_embeddings() / set_output_embeddings() method if it doesn't exist + # this is needed for compatibility + if not hasattr(self, "get_output_embeddings"): + self.get_output_embeddings = lambda: self.lm_head + self.set_output_embeddings = lambda x: setattr(self, "lm_head", x) + + out_embeds = FlamingoDecoupledLinear( + in_features=input_embed_weights.shape[1], + out_features=input_embed_weights.shape[0], + bias=self.get_output_embeddings().bias is not None, + out_additional_features=new_tokens, + partially_freeze=True, + ) + + if getattr(self.config, "tie_word_embeddings", True): + out_embeds.weight = input_embed_weights + else: + out_embeds.weight = self.get_output_embeddings().weight + + if self.get_output_embeddings().bias is not None: + out_embeds.bias = self.get_output_embeddings().bias + + self.set_output_embeddings(out_embeds) + self.initialized_flamingo = True self._use_cached_vision_x = False diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py index 239503f8..c7adf303 100644 --- a/open_flamingo/src/helpers.py +++ b/open_flamingo/src/helpers.py @@ -6,6 +6,7 @@ from einops import rearrange, repeat from einops_exts import rearrange_many from torch import einsum, nn +from torch.nn import functional as F def exists(val): @@ -277,3 +278,187 @@ def forward( x = self.ff(x) * self.ff_gate.tanh() + x return x + + +# Both FlamingoDecoupledEmbedding and FlamingoDecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity + + +class FlamingoDecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, + then it will create `num_additional_embeddings` additional parameters that are always trained. If + `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=False, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + Args: + num_embeddings (`int`): + Size of the dictionary of embeddings + num_additional_embeddings (`int`): + Number of additional embeddings. Only useful when you `partially_freeze=True`. + embedding_dim (`int`): + The size of each embedding vector + partially_freeze: (`bool`, *optional*, defaults to `False`): + If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. + padding_idx (`int`, *optional*): + The padding index (needs to be less than num_embeddings) + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, + `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError( + f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}" + ) + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but + then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - + i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are + usually relatively short it's probably not faster or if faster not by much - but might be a good idea to + measure. + + """ + if self.num_additional_embeddings == 0: + return F.embedding(input_ids, self.weight) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding( + input_ids_additional_vocab - self.num_embeddings + ) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + self.partially_freeze, + ) + + +class FlamingoDecoupledLinear(nn.Linear): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, + then it will create `out_additional_features * in_features` additional parameters that are always trained. If + `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + out_additional_features: int = 0, + bias: bool = True, + partially_freeze: bool = True, + device=None, + dtype=None, + ) -> None: + """ + out_additional_features: int. Number of additional trainable dimensions. Only makes sense when + `partially_freeze=True`. partially_freeze: bool. If True, the regular `weight` will be frozen and extra + parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear. + """ + super().__init__(in_features, out_features, bias, device, dtype) + self.out_additional_features = out_additional_features + self.partially_freeze = partially_freeze + + self.in_features = in_features + self.out_features = out_features + + if partially_freeze: + self.weight.requires_grad_(False) + if bias: + self.bias.requires_grad_(False) + + if out_additional_features > 0: + self.additional_fc = nn.Linear( + in_features=in_features, + out_features=out_additional_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + output = F.linear(input, self.weight, self.bias) + + if self.out_additional_features > 0: + additional_features = F.linear( + input, self.additional_fc.weight, self.additional_fc.bias + ) + output = torch.cat((output, additional_features), -1) + + return output + + def extra_repr(self) -> str: + """Overwriting `nn.Linear.extra_repr` to include new parameters.""" + return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format( + self.in_features, + self.out_features, + self.out_additional_features, + self.bias is not None, + self.partially_freeze, + ) diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index dd9326cf..bfed4f7a 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -122,11 +122,6 @@ def main(): help="we define an 'epoch' as a fixed number of examples (train_num_samples_mmc4, train_num_samples_laion), not a pass through the entire dataset", ) parser.add_argument("--offline", action="store_true") - parser.add_argument( - "--freeze_lm_embeddings", - action="store_true", - help="if True, we freeze the LM embeddings during training. Otherwise, we train the and <|endofchunk|> embeddings.", - ) parser.add_argument( "--logging_steps", type=int, default=100, help="log loss every n steps" ) @@ -237,8 +232,8 @@ def main(): ) args = parser.parse_args() - - args.local_rank = int(os.environ.get('LOCAL_RANK', -1)) # for deepspeed + + args.local_rank = int(os.environ.get("LOCAL_RANK", -1)) # for deepspeed # Validate args if args.laion_shards.startswith("s3"): @@ -253,8 +248,7 @@ def main(): if args.fsdp and not args.fsdp_use_orig_params: print( "Warning: FSDP is running without fsdp_use_orig_params flag. " - + "This is not recommended because it means we will use uniform weight decay" - + " and train all embeddings, not just the newly added ones. " + + "This is not recommended because it means we will use uniform weight decay." + "Note: OPT models are not compatible with fsdp_use_orig_params flag." ) @@ -265,9 +259,6 @@ def main(): + "Copy and paste the code from the _optim_utils.py in this repo into the torch file." + "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state." ) - - if args.deepspeed and args.freeze_lm_embeddings: - raise ValueError("DeepSpeed is not supported with partially frozen LM embeddings") assert (args.train_num_samples_laion // args.batch_size_laion) == ( args.train_num_samples_mmc4 // args.batch_size_mmc4 @@ -336,7 +327,6 @@ def main(): cross_attn_every_n_layers=args.cross_attn_every_n_layers, use_local_files=args.offline, gradient_checkpointing=args.gradient_checkpointing, - freeze_lm_embeddings=args.freeze_lm_embeddings, ) random_seed(args.seed, args.rank) @@ -446,7 +436,9 @@ def main(): ddp_model = DDP(model, device_ids=[device_id]) # Initialize optimizer - params_to_optimize = ddp_model.named_parameters() if not args.deepspeed else model.named_parameters() + params_to_optimize = ( + ddp_model.named_parameters() if not args.deepspeed else model.named_parameters() + ) params_to_optimize = list( filter( lambda x: x[1].requires_grad diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py index 188cb325..30f7398f 100644 --- a/open_flamingo/train/train_utils.py +++ b/open_flamingo/train/train_utils.py @@ -166,31 +166,6 @@ def train_one_epoch( else: (divided_loss_mmc4 * args.loss_multiplier_mmc4).backward() - # TODO: investigate whether this is necessary - if (args.freeze_lm_embeddings) and ( - not args.fsdp or args.fsdp_use_orig_params - ): - # Mask gradients for input embeddings s.t. we only update the added tokens and <|endofchunk|> - if args.fsdp: - embed_grad = model.lang_encoder.get_input_embeddings().weight.grad - else: - embed_grad = ( - model.module.lang_encoder.get_input_embeddings().weight.grad - ) - zero_mask = torch.zeros_like(embed_grad) - zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id]) - zero_mask[endofchunk_token_id] = torch.ones_like( - zero_mask[endofchunk_token_id] - ) - if args.fsdp: - model.lang_encoder.get_input_embeddings().weight.grad = ( - embed_grad * zero_mask - ) - else: - model.module.lang_encoder.get_input_embeddings().weight.grad = ( - embed_grad * zero_mask - ) - # clip gradient norm if args.fsdp: """ From f7b45a428363e5a2eb5e29a2e2e2311a081e0e45 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Thu, 31 Aug 2023 14:42:20 -0700 Subject: [PATCH 03/24] Reverted specified files to the version in origin/main --- open_flamingo/eval/models/open_flamingo.py | 17 +++++------ open_flamingo/scripts/run_train.sh | 34 ++++++++-------------- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/open_flamingo/eval/models/open_flamingo.py b/open_flamingo/eval/models/open_flamingo.py index ce3a6f0b..63ac0f14 100644 --- a/open_flamingo/eval/models/open_flamingo.py +++ b/open_flamingo/eval/models/open_flamingo.py @@ -8,7 +8,7 @@ from open_flamingo.src.factory import create_model_and_transforms from open_flamingo.eval.utils import unwrap_model, get_autocast, get_cast_dtype from transformers.modeling_outputs import CausalLMOutputWithPast -from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + class EvalModel(BaseEvalModel): """OpenFlamingo model evaluation. @@ -47,12 +47,11 @@ def __init__(self, model_args): model_args["lm_tokenizer_path"], cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), ) - self.model = load_state_dict_from_zero_checkpoint(self.model, model_args["checkpoint_path"]) - # checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) - # if "model_state_dict" in checkpoint: - # checkpoint = checkpoint["model_state_dict"] - # checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} - # self.model.load_state_dict(checkpoint, strict=False) + checkpoint = torch.load(model_args["checkpoint_path"], map_location=self.device) + if "model_state_dict" in checkpoint: + checkpoint = checkpoint["model_state_dict"] + checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} + self.model.load_state_dict(checkpoint, strict=False) self.model.to(self.device) self.model.eval() self.tokenizer.padding_side = "left" @@ -115,9 +114,9 @@ def _prepare_text( max_length=max_length, ) input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"] - input_ids = input_ids.to(self.device, non_blocking=True) + input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True) attention_mask = attention_mask.to( - self.device, non_blocking=True + self.device, dtype=self.cast_dtype, non_blocking=True ) return input_ids, attention_mask.bool() diff --git a/open_flamingo/scripts/run_train.sh b/open_flamingo/scripts/run_train.sh index dce882e1..8d45355e 100644 --- a/open_flamingo/scripts/run_train.sh +++ b/open_flamingo/scripts/run_train.sh @@ -1,11 +1,7 @@ #!/bin/bash #SBATCH --nodes 1 -#SBATCH --ntasks-per-node=6 +#SBATCH --ntasks-per-node=8 #SBATCH --gpus-per-task=1 -#SBATCH --account=efml -#SBATCH --partition=gpu -#SBATCH --time=48:00:00 -#SBATCH --job-name=flamingo export PYTHONFAULTHANDLER=1 export CUDA_LAUNCH_BLOCKING=0 @@ -13,30 +9,24 @@ export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) export MASTER_PORT=15000 export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` -export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface" export PYTHONPATH="$PYTHONPATH:open_flamingo" -srun --cpu_bind=v --accel-bind=gn python - - - -deepspeed open_flamingo/open_flamingo/train/train.py \ - --lm_path meta-llama/Llama-2-13b \ - --tokenizer_path meta-llama/Llama-2-13b \ - --cross_attn_every_n_layers 4 \ +srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \ + --lm_path anas-awadalla/mpt-1b-redpajama-200b \ + --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \ + --cross_attn_every_n_layers 1 \ --dataset_resampled \ - --batch_size_mmc4 16 \ - --batch_size_laion 32 \ - --deepspeed \ + --batch_size_mmc4 32 \ + --batch_size_laion 64 \ --train_num_samples_mmc4 125000\ --train_num_samples_laion 250000 \ --loss_multiplier_laion 0.2 \ --workers=4 \ - --run_name "deepspeed" \ + --run_name OpenFlamingo-3B-vitl-mpt1b \ --num_epochs 480 \ - --warmup_steps 0 \ - --mmc4_textsim_threshold 0.0 \ - --laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \ - --mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \ + --warmup_steps 1875 \ + --mmc4_textsim_threshold 0.24 \ + --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \ + --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \ --gradient_checkpointing \ --report_to_wandb \ From 176bbc0cc5cfb9fd63112d42da8e4a139e835e84 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Thu, 31 Aug 2023 14:43:00 -0700 Subject: [PATCH 04/24] added deepspeed to reqs --- requirements-training.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-training.txt b/requirements-training.txt index 79ff0bc9..8b46a831 100644 --- a/requirements-training.txt +++ b/requirements-training.txt @@ -3,3 +3,4 @@ braceexpand webdataset tqdm wandb +deepspeed \ No newline at end of file From 71ddca66dc6bc4d0cee42f428bc61cf3d2f27cf4 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Thu, 31 Aug 2023 16:54:42 -0700 Subject: [PATCH 05/24] remove stage3 16 bit weights and local_rank arg --- open_flamingo/train/train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index bfed4f7a..3418efbb 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -232,8 +232,7 @@ def main(): ) args = parser.parse_args() - - args.local_rank = int(os.environ.get("LOCAL_RANK", -1)) # for deepspeed + args.local_rank, args.rank, args.world_size = world_info_from_env() # Validate args if args.laion_shards.startswith("s3"): @@ -269,7 +268,6 @@ def main(): os.environ["WANDB_MODE"] = "offline" os.environ["TRANSFORMERS_OFFLINE"] = "1" - args.local_rank, args.rank, args.world_size = world_info_from_env() if args.deepspeed: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) @@ -285,7 +283,6 @@ def main(): "stage3_param_persistence_threshold": 1e4, "stage3_max_live_parameters": 3e7, "stage3_prefetch_bucket_size": 3e7, - "stage3_gather_16bit_weights_on_model_save": True, "memory_efficient_linear": False, } ds_config = { From 76f4f8c892496a7482bfa4cd2c116b9479552dcd Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Thu, 31 Aug 2023 16:58:29 -0700 Subject: [PATCH 06/24] move get_embed func to factory.py and restore req grad call --- open_flamingo/src/factory.py | 7 ++++++- open_flamingo/src/flamingo_lm.py | 6 ------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 13917f60..3a194eab 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -80,6 +80,10 @@ def set_input_embeddings(self, new_embeddings): self.transformer.wte = new_embeddings extend_instance(lang_encoder, EmbeddingFnMixin) + + if not hasattr(lang_encoder, "get_output_embeddings"): + lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head + lang_encoder.set_output_embeddings = lambda x: setattr(lang_encoder, "lm_head", x) # convert LM to FlamingoLM extend_instance(lang_encoder, FlamingoLMMixin) @@ -105,7 +109,8 @@ def set_input_embeddings(self, new_embeddings): model.vision_encoder.requires_grad_(False) model.lang_encoder.requires_grad_(False) - # Unfreeze gated_cross_attn_layers + # Unfreeze gated_cross_attn_layers and perceiver + model.perceiver.requires_grad_(True) model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) print( diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index 643da8a3..2d5b4186 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -122,12 +122,6 @@ def init_flamingo( ) self.get_input_embeddings().weight = input_embed_weights - # create a get_output_embeddings() / set_output_embeddings() method if it doesn't exist - # this is needed for compatibility - if not hasattr(self, "get_output_embeddings"): - self.get_output_embeddings = lambda: self.lm_head - self.set_output_embeddings = lambda x: setattr(self, "lm_head", x) - out_embeds = FlamingoDecoupledLinear( in_features=input_embed_weights.shape[1], out_features=input_embed_weights.shape[0], From 400fffc6e1d0c343dfee687bbba95104297494b6 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Fri, 1 Sep 2023 04:17:17 +0000 Subject: [PATCH 07/24] fix bias check --- open_flamingo/src/flamingo_lm.py | 2 +- open_flamingo/train/train_utils.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index 2d5b4186..5daedb8c 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -135,7 +135,7 @@ def init_flamingo( else: out_embeds.weight = self.get_output_embeddings().weight - if self.get_output_embeddings().bias is not None: + if getattr(self.get_output_embeddings(), "bias", None): out_embeds.bias = self.get_output_embeddings().bias self.set_output_embeddings(out_embeds) diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py index 30f7398f..2be50a1a 100644 --- a/open_flamingo/train/train_utils.py +++ b/open_flamingo/train/train_utils.py @@ -72,9 +72,6 @@ def train_one_epoch( # setup model media_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] - endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[ - "input_ids" - ][-1] model.train() # setup logging From bb017177919c496acad3d7da425f2a3428745585 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Fri, 1 Sep 2023 04:37:50 +0000 Subject: [PATCH 08/24] another bias fix --- open_flamingo/src/factory.py | 6 ++++-- open_flamingo/src/flamingo_lm.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 3a194eab..67e838e4 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -80,10 +80,12 @@ def set_input_embeddings(self, new_embeddings): self.transformer.wte = new_embeddings extend_instance(lang_encoder, EmbeddingFnMixin) - + if not hasattr(lang_encoder, "get_output_embeddings"): lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head - lang_encoder.set_output_embeddings = lambda x: setattr(lang_encoder, "lm_head", x) + lang_encoder.set_output_embeddings = lambda x: setattr( + lang_encoder, "lm_head", x + ) # convert LM to FlamingoLM extend_instance(lang_encoder, FlamingoLMMixin) diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index 5daedb8c..ff5c4f52 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -125,7 +125,7 @@ def init_flamingo( out_embeds = FlamingoDecoupledLinear( in_features=input_embed_weights.shape[1], out_features=input_embed_weights.shape[0], - bias=self.get_output_embeddings().bias is not None, + bias=getattr(self.get_output_embeddings(), "bias", None) is not None, out_additional_features=new_tokens, partially_freeze=True, ) From 104975ca424f6c7ccd3191526e3ec30ab00f93a4 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Thu, 31 Aug 2023 22:02:09 -0700 Subject: [PATCH 09/24] remove trust_remote_code as mpt is part of transformers now --- open_flamingo/src/factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 67e838e4..eb1f9a55 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -48,7 +48,6 @@ def create_model_and_transforms( text_tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, local_files_only=use_local_files, - trust_remote_code=True, cache_dir=cache_dir, ) # add Flamingo special tokens to the tokenizer @@ -65,7 +64,6 @@ def create_model_and_transforms( lang_encoder = AutoModelForCausalLM.from_pretrained( lang_encoder_path, local_files_only=use_local_files, - trust_remote_code=True, cache_dir=cache_dir, ) From 17babfedc0035d60bebc3b6559e3bc48bd301dea Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Fri, 1 Sep 2023 20:38:11 -0700 Subject: [PATCH 10/24] add lm_head check --- open_flamingo/src/factory.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index eb1f9a55..5523448d 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -80,10 +80,22 @@ def set_input_embeddings(self, new_embeddings): extend_instance(lang_encoder, EmbeddingFnMixin) if not hasattr(lang_encoder, "get_output_embeddings"): - lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head - lang_encoder.set_output_embeddings = lambda x: setattr( - lang_encoder, "lm_head", x - ) + if hasattr(lang_encoder, "lm_head"): + lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head + else: + raise ValueError( + "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this string manually." + ) + + if not hasattr(lang_encoder, "set_output_embeddings"): + if hasattr(lang_encoder, "lm_head"): + lang_encoder.set_output_embeddings = lambda x: setattr( + lang_encoder, "lm_head", x + ) + else: + raise ValueError( + "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this string manually." + ) # convert LM to FlamingoLM extend_instance(lang_encoder, FlamingoLMMixin) From 10afa74d18e545488ff658048062b71995448757 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Fri, 1 Sep 2023 22:19:12 -0700 Subject: [PATCH 11/24] more changes --- open_flamingo/src/factory.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 5523448d..62cdb00c 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -49,6 +49,7 @@ def create_model_and_transforms( tokenizer_path, local_files_only=use_local_files, cache_dir=cache_dir, + trust_remote_code=True ) # add Flamingo special tokens to the tokenizer text_tokenizer.add_special_tokens( @@ -65,6 +66,7 @@ def create_model_and_transforms( lang_encoder_path, local_files_only=use_local_files, cache_dir=cache_dir, + trust_remote_code=True ) # hacks for MPT-1B, which doesn't have a get_input_embeddings method @@ -84,7 +86,7 @@ def set_input_embeddings(self, new_embeddings): lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head else: raise ValueError( - "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this string manually." + "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." ) if not hasattr(lang_encoder, "set_output_embeddings"): @@ -94,7 +96,7 @@ def set_input_embeddings(self, new_embeddings): ) else: raise ValueError( - "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this string manually." + "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." ) # convert LM to FlamingoLM From a248c261de6a13ec3536de6a8bd4d164fd0be208 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Sat, 2 Sep 2023 00:55:54 -0700 Subject: [PATCH 12/24] Update factory.py --- open_flamingo/src/factory.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 62cdb00c..61c0ef95 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -126,6 +126,11 @@ def set_input_embeddings(self, new_embeddings): # Unfreeze gated_cross_attn_layers and perceiver model.perceiver.requires_grad_(True) model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) + if hasattr(model.lang_encoder.get_output_embeddings(), "additional_fc"): + model.lang_encoder.get_output_embeddings().additional_fc.requires_grad_(True) + + if hasattr(model.lang_encoder.get_input_embeddings(), "additional_embedding"): + model.lang_encoder.get_input_embeddings().additional_embedding.requires_grad_(True) print( f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" From c453ca8e6a1b4fb3127b7806476dcbe5aa2143e3 Mon Sep 17 00:00:00 2001 From: Irena Gao Date: Mon, 4 Sep 2023 09:29:18 -0700 Subject: [PATCH 13/24] init on device to avoid cpu oom --- open_flamingo/eval/eval_model.py | 31 ++++++++-------- open_flamingo/eval/evaluate.py | 42 ++++++++++++---------- open_flamingo/eval/models/blip.py | 15 ++++---- open_flamingo/eval/models/idefics.py | 11 +++--- open_flamingo/eval/models/open_flamingo.py | 27 +++++++------- 5 files changed, 68 insertions(+), 58 deletions(-) diff --git a/open_flamingo/eval/eval_model.py b/open_flamingo/eval/eval_model.py index c011deea..97b5843d 100644 --- a/open_flamingo/eval/eval_model.py +++ b/open_flamingo/eval/eval_model.py @@ -4,8 +4,7 @@ from PIL import Image from utils import get_autocast, get_cast_dtype import torch -from transformers.modeling_outputs import CausalLMOutputWithPast - +from contextlib import suppress class BaseEvalModel(abc.ABC): """Base class encapsulating functionality needed to evaluate a model.""" @@ -19,17 +18,23 @@ def __init__(self, model_args: List[str]): is non-empty. """ - def __init__(self, model_args): + def __init__(self, model_args, init_on_device=False): assert "lm_path" in model_args, "All models require the lm_path argument" self.device = ( model_args["device"] - if ("device" in model_args and model_args["device"] >= 0) + if ("device" in model_args and (type(model_args["device"]) != int or model_args["device"] >= 0)) else "cpu" ) - precision = model_args.get("precision", "fp32") + self.precision = model_args.get("precision", "fp32") self.lm_name = model_args["lm_path"].split("/")[-1] - self.autocast = get_autocast(precision) - self.cast_dtype = get_cast_dtype(precision) + self.autocast = get_autocast(self.precision) + self.cast_dtype = get_cast_dtype(self.precision) + if init_on_device: + # for deepspeed, must init on device, or likely CPU OOM + import deepspeed + self.init_ctx = deepspeed.OnDevice(dtype=self.cast_dtype, device=self.device) + else: + self.init_ctx = suppress() def _check_init(self): """Finish model initialization.""" @@ -42,8 +47,8 @@ def _check_init(self): def init_distributed(self, world_size=None, use_deepspeed=False): """Wrap model as DDP or deepspeed.""" if use_deepspeed: + assert "amp" not in self.precision, "Deepspeed does not support amp" import deepspeed - self.ds_engine = deepspeed.init_inference( self.model, mp_size=world_size, @@ -52,13 +57,15 @@ def init_distributed(self, world_size=None, use_deepspeed=False): replace_with_kernel_inject=True, ) self.model = self.ds_engine.module + self.autocast = get_autocast(None) else: self.model = DDP(self.model, device_ids=[self.device]) def set_device(self, device): """Set device for model.""" - self.device = device - self.model = self.model.to(device) + torch.cuda.set_device(device) + self.device = torch.device("cuda", device) + self.model = self.model.to(device, dtype=self.cast_dtype) def __call__( self, @@ -120,10 +127,6 @@ def get_outputs( of any images to be included. batch_images: images to provide to model. Should be a list of lists, where each list contains the images for a single example. - max_generation_length: maximum length of the generated caption. - Defaults to 10. - num_beams: number of beams to use for beam search. Defaults to 3. - length_penalty: length penalty for beam search. Defaults to -2.0. Returns: List of decoded output strings. diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index 872668f3..daecbead 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -382,6 +382,12 @@ action="store_true", help="Use horovod for distributed training.", ) +parser.add_argument( + "--local_rank", + default=0, + type=int, + help="Rank of distributed process (default: 0). Usually overwritten by world_info_from_env()", +) parser.add_argument( "--no-set-device-rank", default=False, @@ -395,22 +401,20 @@ help="Whether to use deepspeed for distributed inference.", ) - def main(): args, leftovers = parser.parse_known_args() module = importlib.import_module(f"open_flamingo.eval.models.{args.model}") - model_args = { - leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers), 2) - } - eval_model = module.EvalModel(model_args) - # set up distributed evaluation args.local_rank, args.rank, args.world_size = world_info_from_env() device_id = init_distributed_device(args) - eval_model.set_device(device_id) + model_args = { + leftovers[i].lstrip("-"): leftovers[i + 1] for i in range(0, len(leftovers), 2) + } + model_args['device'] = device_id + eval_model = module.EvalModel(model_args, init_on_device=args.deepspeed) eval_model.init_distributed( - world_size=args.world_size, use_deepspeed=args.deepspeed + local_rank=args.local_rank, world_size=args.world_size, use_deepspeed=args.deepspeed ) if args.model != "open_flamingo" and args.shots != [0]: @@ -626,7 +630,7 @@ def main(): num_shots=shot, seed=seed, dataset_name="textvqa", - max_generation_length=10, + max_new_tokens=10, cached_features=cached_features, ) if args.rank == 0: @@ -737,8 +741,8 @@ def evaluate_captioning( args: argparse.Namespace, eval_model: BaseEvalModel, seed: int = 42, - min_generation_length: int = 0, - max_generation_length: int = 20, + min_new_tokens: int = 0, + max_new_tokens: int = 20, num_beams: int = 3, length_penalty: float = 0.0, num_shots: int = 8, @@ -751,7 +755,7 @@ def evaluate_captioning( args (argparse.Namespace): arguments eval_model (BaseEvalModel): model to evaluate seed (int, optional): seed for random number generator. Defaults to 42. - max_generation_length (int, optional): maximum length of the generated caption. Defaults to 20. + max_new_tokens (int, optional): maximum length of the generated caption. Defaults to 20. num_beams (int, optional): number of beams to use for beam search. Defaults to 3. length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. num_shots (int, optional): number of in-context samples to use. Defaults to 8. @@ -851,8 +855,8 @@ def evaluate_captioning( outputs = eval_model.get_outputs( batch_images=batch_images, batch_text=batch_text, - min_generation_length=min_generation_length, - max_generation_length=max_generation_length, + min_new_tokens=min_new_tokens, + max_new_tokens=max_new_tokens, num_beams=num_beams, length_penalty=length_penalty, ) @@ -908,8 +912,8 @@ def evaluate_vqa( args: argparse.Namespace, eval_model: BaseEvalModel, seed: int = 42, - min_generation_length: int = 0, - max_generation_length: int = 5, + min_new_tokens: int = 0, + max_new_tokens: int = 5, num_beams: int = 3, length_penalty: float = 0.0, num_shots: int = 8, @@ -923,7 +927,7 @@ def evaluate_vqa( args (argparse.Namespace): arguments eval_model (BaseEvalModel): model to evaluate seed (int, optional): random seed. Defaults to 42. - max_generation_length (int, optional): max generation length. Defaults to 5. + max_new_tokens (int, optional): max generation length. Defaults to 5. num_beams (int, optional): number of beams to use for beam search. Defaults to 3. length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. num_shots (int, optional): number of shots to use. Defaults to 8. @@ -1044,8 +1048,8 @@ def evaluate_vqa( outputs = eval_model.get_outputs( batch_images=batch_images, batch_text=batch_text, - min_generation_length=min_generation_length, - max_generation_length=max_generation_length, + min_new_tokens=min_new_tokens, + max_new_tokens=max_new_tokens, num_beams=num_beams, length_penalty=length_penalty, ) diff --git a/open_flamingo/eval/models/blip.py b/open_flamingo/eval/models/blip.py index 82cbdd14..1ae3e0a5 100644 --- a/open_flamingo/eval/models/blip.py +++ b/open_flamingo/eval/models/blip.py @@ -12,16 +12,17 @@ class EvalModel(BaseEvalModel): """BLIP-2 model evaluation.""" - def __init__(self, **model_args): + def __init__(self, model_args, init_on_device=False): assert ( "processor_path" in model_args and "lm_path" in model_args ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified" - super().__init__(model_args) - self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) - self.model = Blip2ForConditionalGeneration.from_pretrained( - model_args["lm_path"] - ) - self.tokenizer = self.processor.tokenizer + super().__init__(model_args, init_on_device) + with self.init_ctx: + self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) + self.model = Blip2ForConditionalGeneration.from_pretrained( + model_args["lm_path"] + ) + self.tokenizer = self.processor.tokenizer self._check_init() def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: diff --git a/open_flamingo/eval/models/idefics.py b/open_flamingo/eval/models/idefics.py index 8707eb91..f8c50fba 100644 --- a/open_flamingo/eval/models/idefics.py +++ b/open_flamingo/eval/models/idefics.py @@ -17,14 +17,15 @@ class EvalModel(BaseEvalModel): """IDEFICS model evaluation.""" - def __init__(self, **model_args): + def __init__(self, model_args, init_on_device=False): assert ( "lm_path" in model_args and "processor_path" in model_args ), "IDEFICS requires lm_path and lm_tokenizer_path" - super().__init__(model_args) - self.model = IdeficsForVisionText2Text.from_pretrained(model_args["lm_path"]) - self.processor = AutoProcessor.from_pretrained(model_args["processor_path"]) - self.tokenizer = self.processor.tokenizer + super().__init__(model_args, init_on_device) + with self.init_ctx: + self.model = IdeficsForVisionText2Text.from_pretrained(model_args["lm_path"]) + self.processor = AutoProcessor.from_pretrained(model_args["processor_path"]) + self.tokenizer = self.processor.tokenizer self._check_init() def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor: diff --git a/open_flamingo/eval/models/open_flamingo.py b/open_flamingo/eval/models/open_flamingo.py index 0fe848ea..624d96f2 100644 --- a/open_flamingo/eval/models/open_flamingo.py +++ b/open_flamingo/eval/models/open_flamingo.py @@ -13,7 +13,7 @@ class EvalModel(BaseEvalModel): """OpenFlamingo model evaluation.""" - def __init__(self, model_args): + def __init__(self, model_args, init_on_device=False): assert ( "vision_encoder_path" in model_args and "lm_path" in model_args @@ -22,18 +22,19 @@ def __init__(self, model_args): and "cross_attn_every_n_layers" in model_args and "vision_encoder_pretrained" in model_args ), "OpenFlamingo requires vision_encoder_path, lm_path, device, checkpoint_path, lm_tokenizer_path, cross_attn_every_n_layers, vision_encoder_pretrained arguments to be specified" - super().__init__(model_args) - ( - self.model, - self.image_processor, - self.tokenizer, - ) = create_model_and_transforms( - model_args["vision_encoder_path"], - model_args["vision_encoder_pretrained"], - model_args["lm_path"], - model_args["lm_tokenizer_path"], - cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), - ) + super().__init__(model_args, init_on_device) + with self.init_ctx: + ( + self.model, + self.image_processor, + self.tokenizer, + ) = create_model_and_transforms( + model_args["vision_encoder_path"], + model_args["vision_encoder_pretrained"], + model_args["lm_path"], + model_args["lm_tokenizer_path"], + cross_attn_every_n_layers=int(model_args["cross_attn_every_n_layers"]), + ) checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu") if "model_state_dict" in checkpoint: checkpoint = checkpoint["model_state_dict"] From de187aa27befaba9a1567da8ca0727e25641b166 Mon Sep 17 00:00:00 2001 From: Irena Gao Date: Mon, 4 Sep 2023 10:28:18 -0700 Subject: [PATCH 14/24] update eval script to use deepspeed --- open_flamingo/scripts/run_eval.sh | 5 +++-- open_flamingo/src/factory.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/open_flamingo/scripts/run_eval.sh b/open_flamingo/scripts/run_eval.sh index d347d0a4..db869b09 100644 --- a/open_flamingo/scripts/run_eval.sh +++ b/open_flamingo/scripts/run_eval.sh @@ -22,7 +22,8 @@ echo go $COUNT_NODE echo $HOSTNAMES export PYTHONPATH="$PYTHONPATH:open_flamingo" -srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/eval/evaluate.py \ +srun --cpu_bind=v --accel-bind=gn python +deepspeed open_flamingo/open_flamingo/eval/evaluate.py \ --vision_encoder_path ViT-L-14 \ --vision_encoder_pretrained openai\ --lm_path anas-awadalla/mpt-1b-redpajama-200b \ @@ -30,7 +31,7 @@ srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/eval/evalua --cross_attn_every_n_layers 1 \ --checkpoint_path "openflamingo/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt" \ --results_file "results.json" \ - --precision amp_bf16 \ + --precision fp32 \ --batch_size 8 \ --eval_coco \ --eval_vqav2 \ diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 5dd4a4f2..15d60d52 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -88,7 +88,7 @@ def set_input_embeddings(self, new_embeddings): decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) lang_encoder.resize_token_embeddings( - len(text_tokenizer), pad_to_multiple_of=8 + len(text_tokenizer), ) # padding to enable tensor cores model = Flamingo( From 6f620541a81f47d1085aadaf6b6de00554029ee9 Mon Sep 17 00:00:00 2001 From: Irena Gao Date: Mon, 4 Sep 2023 10:30:24 -0700 Subject: [PATCH 15/24] restore pad_to_multiple_of kwarg in factory --- open_flamingo/src/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 15d60d52..0fd4ec42 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -88,7 +88,7 @@ def set_input_embeddings(self, new_embeddings): decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) lang_encoder.resize_token_embeddings( - len(text_tokenizer), + len(text_tokenizer), pad_to_multiple_of=8, ) # padding to enable tensor cores model = Flamingo( From e626afb16078841bde285c32c940647a636c57e4 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Mon, 4 Sep 2023 18:21:56 -0700 Subject: [PATCH 16/24] fixed embed not training issue --- open_flamingo/src/factory.py | 17 +++++++++++++++-- open_flamingo/src/flamingo.py | 7 +++++-- open_flamingo/src/flamingo_lm.py | 5 +++-- open_flamingo/src/helpers.py | 4 ++-- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 62cdb00c..242d896d 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -61,7 +61,12 @@ def create_model_and_transforms( # modify labels for the loss. text_tokenizer.add_special_tokens({"pad_token": ""}) new_tokens += 1 - + + ids_for_additional_special_tokens = text_tokenizer.convert_tokens_to_ids( + ["<|endofchunk|>", "", ""] if new_tokens == 3 else ["<|endofchunk|>", ""] + ) + print(f"Added {new_tokens} new tokens to the tokenizer") + lang_encoder = AutoModelForCausalLM.from_pretrained( lang_encoder_path, local_files_only=use_local_files, @@ -105,7 +110,7 @@ def set_input_embeddings(self, new_embeddings): if decoder_layers_attr_name is None: decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) - + model = Flamingo( vision_encoder, lang_encoder, @@ -115,6 +120,8 @@ def set_input_embeddings(self, new_embeddings): "width" ], cross_attn_every_n_layers=cross_attn_every_n_layers, + # HACK: The tokenizer's size and model's vocab size sometimes don't match. We use this to find the smaller of the flamingo special tokens and use that as the vocab size (even though the true one might be smaller). + vocab_size=min(ids_for_additional_special_tokens), new_tokens=new_tokens, # number of tokens embeddings to train **flamingo_kwargs, ) @@ -127,6 +134,12 @@ def set_input_embeddings(self, new_embeddings): model.perceiver.requires_grad_(True) model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) + if hasattr(model.lang_encoder.get_output_embeddings(), "additional_fc"): + model.lang_encoder.get_output_embeddings().additional_fc.requires_grad_(True) + + if hasattr(model.lang_encoder.get_input_embeddings(), "additional_embedding"): + model.lang_encoder.get_input_embeddings().additional_embedding.requires_grad_(True) + print( f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" ) diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py index ee02b65f..28b0304a 100644 --- a/open_flamingo/src/flamingo.py +++ b/open_flamingo/src/flamingo.py @@ -22,9 +22,10 @@ def __init__( eoc_token_id: int, media_token_id: int, vis_dim: int, + vocab_size: int, + new_tokens: int, cross_attn_every_n_layers: int = 1, gradient_checkpointing: bool = False, - new_tokens: int = 2, ): """ Args: @@ -34,9 +35,10 @@ def __init__( media_token_id (int): Token id for vis_dim (int): Dimension of the visual features. Visual features are projected to match this shape along the last dimension. + vocab_size (int): Size of the base vocabulary. + new_tokens (int): Number of new tokens added to the tokenizer. cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. - new_tokens (int, optional): Number of new tokens added to the tokenizer. Defaults to 2. """ super().__init__() self.eoc_token_id = eoc_token_id @@ -56,6 +58,7 @@ def __init__( vis_hidden_size=self.vis_dim, cross_attn_every_n_layers=cross_attn_every_n_layers, gradient_checkpointing=gradient_checkpointing, + vocab_size=vocab_size, new_tokens=new_tokens, ) self._use_gradient_checkpointing = gradient_checkpointing diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index ff5c4f52..aa82c81d 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -91,6 +91,7 @@ def init_flamingo( vis_hidden_size, cross_attn_every_n_layers, gradient_checkpointing, + vocab_size, new_tokens, ): """ @@ -114,7 +115,7 @@ def init_flamingo( input_embed_weights = self.get_input_embeddings().weight self.set_input_embeddings( FlamingoDecoupledEmbedding( - num_embeddings=input_embed_weights.shape[0], + num_embeddings=vocab_size, num_additional_embeddings=new_tokens, embedding_dim=input_embed_weights.shape[1], partially_freeze=True, @@ -124,7 +125,7 @@ def init_flamingo( out_embeds = FlamingoDecoupledLinear( in_features=input_embed_weights.shape[1], - out_features=input_embed_weights.shape[0], + out_features=vocab_size, bias=getattr(self.get_output_embeddings(), "bias", None) is not None, out_additional_features=new_tokens, partially_freeze=True, diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py index c7adf303..43b0f429 100644 --- a/open_flamingo/src/helpers.py +++ b/open_flamingo/src/helpers.py @@ -297,7 +297,7 @@ def __init__( num_embeddings, num_additional_embeddings, embedding_dim, - partially_freeze=False, + partially_freeze=True, device=None, dtype=None, padding_idx=None, @@ -311,7 +311,7 @@ def __init__( Number of additional embeddings. Only useful when you `partially_freeze=True`. embedding_dim (`int`): The size of each embedding vector - partially_freeze: (`bool`, *optional*, defaults to `False`): + partially_freeze: (`bool`, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. padding_idx (`int`, *optional*): The padding index (needs to be less than num_embeddings) From ecc74ad2567fca6ed35db0f6e943f2b239c0129f Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Tue, 5 Sep 2023 01:23:34 +0000 Subject: [PATCH 17/24] embed training --- open_flamingo/src/factory.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index eb1f9a55..0ea939f7 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -112,6 +112,12 @@ def set_input_embeddings(self, new_embeddings): # Unfreeze gated_cross_attn_layers and perceiver model.perceiver.requires_grad_(True) model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) + + if hasattr(model.lang_encoder.get_output_embeddings(), "additional_fc"): + model.lang_encoder.get_output_embeddings().additional_fc.requires_grad_(True) + + if hasattr(model.lang_encoder.get_input_embeddings(), "additional_embedding"): + model.lang_encoder.get_input_embeddings().additional_embedding.requires_grad_(True) print( f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters" From fe39d1a286ba6d90973fb0d748172a5ef1c58250 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Mon, 4 Sep 2023 18:56:54 -0700 Subject: [PATCH 18/24] tie decoupled embeddings --- open_flamingo/src/flamingo_lm.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index aa82c81d..444baf81 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -131,16 +131,19 @@ def init_flamingo( partially_freeze=True, ) + if getattr(self.get_output_embeddings(), "bias", None): + out_embeds.bias = self.get_output_embeddings().bias + + self.set_output_embeddings(out_embeds) + if getattr(self.config, "tie_word_embeddings", True): out_embeds.weight = input_embed_weights + if self.get_input_embeddings().num_additional_embeddings > 0: + assert self.get_output_embeddings().out_additional_features == self.get_input_embeddings().num_additional_embeddings + self.get_output_embeddings().additional_fc.weight = self.get_input_embeddings().additional_embedding.weight else: out_embeds.weight = self.get_output_embeddings().weight - if getattr(self.get_output_embeddings(), "bias", None): - out_embeds.bias = self.get_output_embeddings().bias - - self.set_output_embeddings(out_embeds) - self.initialized_flamingo = True self._use_cached_vision_x = False From 8bd72738a71ee1cc955e76e271e5b1d6389b998d Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Mon, 4 Sep 2023 19:24:28 -0700 Subject: [PATCH 19/24] untie embeds for fsdp --- open_flamingo/src/factory.py | 17 ++++++++++++++++- open_flamingo/train/train.py | 5 +++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index c1a9bfe1..c62686c4 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -2,6 +2,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import open_clip +import torch.nn as nn from .flamingo import Flamingo from .flamingo_lm import FlamingoLMMixin @@ -14,6 +15,7 @@ def create_model_and_transforms( lang_encoder_path: str, tokenizer_path: str, cross_attn_every_n_layers: int = 1, + untie_embeddings: bool = True, use_local_files: bool = False, decoder_layers_attr_name: str = None, cache_dir: Optional[str] = None, @@ -29,6 +31,7 @@ def create_model_and_transforms( lang_encoder_path (str): path to pretrained language encoder tokenizer_path (str): path to pretrained tokenizer cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. + untie_embeddings (bool, optional): whether to untie the input and output embeddings if they are tied. This is required for training using FSDP. Defaults to False. use_local_files (bool, optional): whether to use local files. Defaults to False. decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. @@ -103,7 +106,19 @@ def set_input_embeddings(self, new_embeddings): raise ValueError( "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py." ) - + + if untie_embeddings: + new_output_embeddings_weight = lang_encoder.get_output_embeddings().weight.clone() + if lang_encoder.get_output_embeddings().bias is not None: + new_output_embeddings_bias = lang_encoder.get_output_embeddings().bias.clone() + else: + new_output_embeddings_bias = None + lang_encoder.get_output_embeddings().weight = nn.Parameter(new_output_embeddings_weight) + if new_output_embeddings_bias is not None: + lang_encoder.get_output_embeddings().bias = nn.Parameter(new_output_embeddings_bias) + + lang_encoder.config.update({"tie_word_embeddings": False}) + # convert LM to FlamingoLM extend_instance(lang_encoder, FlamingoLMMixin) diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index 3418efbb..0a403b2c 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -314,6 +314,9 @@ def main(): device_id = init_distributed_device(args) random_seed(args.seed) + + if args.fsdp: + print("Untying embeddings for FSDP") # Initialize model model, image_processor, tokenizer = create_model_and_transforms( @@ -322,6 +325,7 @@ def main(): args.lm_path, args.tokenizer_path if args.tokenizer_path else args.lm_path, cross_attn_every_n_layers=args.cross_attn_every_n_layers, + untie_embeddings=args.fsdp, # untie embeddings for FSDP use_local_files=args.offline, gradient_checkpointing=args.gradient_checkpointing, ) @@ -436,6 +440,7 @@ def main(): params_to_optimize = ( ddp_model.named_parameters() if not args.deepspeed else model.named_parameters() ) + params_to_optimize = list( filter( lambda x: x[1].requires_grad From d10a9981dfec6122384625da9c9ab8a78a1ca435 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Mon, 4 Sep 2023 19:29:19 -0700 Subject: [PATCH 20/24] move grad checkpointing before optimizer creation --- open_flamingo/train/train.py | 38 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index 0a403b2c..ef9d243b 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -435,6 +435,25 @@ def main(): elif not args.deepspeed: model = model.to(device_id) ddp_model = DDP(model, device_ids=[device_id]) + + # Initialize gradient checkpointing + if args.gradient_checkpointing: + if args.deepspeed: + raise ValueError( + "gradient checkpointing currently not supported with deepspeed" + ) + non_reentrant_wrapper = functools.partial( + checkpoint_wrapper, + offload_to_cpu=True, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + apply_activation_checkpointing( + ddp_model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False) + and not isinstance(m, FSDP) + and not isinstance(m, CheckpointWrapper), + ) # Initialize optimizer params_to_optimize = ( @@ -542,25 +561,6 @@ def get_grouped_params(model): checkpoint_epoch = int(f.read().split("_")[-1]) resume_from_epoch = checkpoint_epoch + 1 - # Initialize gradient checkpointing - if args.gradient_checkpointing: - if args.deepspeed: - raise ValueError( - "gradient checkpointing currently not supported with deepspeed" - ) - non_reentrant_wrapper = functools.partial( - checkpoint_wrapper, - offload_to_cpu=True, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - apply_activation_checkpointing( - ddp_model, - checkpoint_wrapper_fn=non_reentrant_wrapper, - check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False) - and not isinstance(m, FSDP) - and not isinstance(m, CheckpointWrapper), - ) - for epoch in range(resume_from_epoch, args.num_epochs): laion_dataset.set_epoch(epoch) laion_loader = laion_dataset.dataloader From 4f3ce24b59f00bbfbc0f58fe5d054891c0991836 Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Mon, 4 Sep 2023 19:30:06 -0700 Subject: [PATCH 21/24] default is not to untie --- open_flamingo/src/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index c62686c4..987da0c6 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -15,7 +15,7 @@ def create_model_and_transforms( lang_encoder_path: str, tokenizer_path: str, cross_attn_every_n_layers: int = 1, - untie_embeddings: bool = True, + untie_embeddings: bool = False, use_local_files: bool = False, decoder_layers_attr_name: str = None, cache_dir: Optional[str] = None, From 2c5d86450e0946493c8c5ffdbb5bcda43375a80b Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Mon, 4 Sep 2023 22:14:04 -0700 Subject: [PATCH 22/24] fix embed init --- open_flamingo/src/flamingo_lm.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index 444baf81..16cb04bf 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -112,38 +112,29 @@ def init_flamingo( self.media_token_id = media_token_id # modify the embedding layer to support decoupling - input_embed_weights = self.get_input_embeddings().weight - self.set_input_embeddings( - FlamingoDecoupledEmbedding( - num_embeddings=vocab_size, - num_additional_embeddings=new_tokens, - embedding_dim=input_embed_weights.shape[1], - partially_freeze=True, - ) + input_embeds = FlamingoDecoupledEmbedding( + num_embeddings=vocab_size, + num_additional_embeddings=new_tokens, + embedding_dim=self.get_input_embeddings().weight.shape[1], + partially_freeze=True, ) - self.get_input_embeddings().weight = input_embed_weights + input_embeds.weight = self.get_input_embeddings().weight + self.set_input_embeddings(input_embeds) out_embeds = FlamingoDecoupledLinear( - in_features=input_embed_weights.shape[1], + in_features=self.get_input_embeddings().weight.shape[1], out_features=vocab_size, - bias=getattr(self.get_output_embeddings(), "bias", None) is not None, + bias=self.get_output_embeddings().bias is not None, out_additional_features=new_tokens, partially_freeze=True, ) - if getattr(self.get_output_embeddings(), "bias", None): + if self.get_output_embeddings().bias is not None: out_embeds.bias = self.get_output_embeddings().bias + out_embeds.weight = self.get_output_embeddings().weight self.set_output_embeddings(out_embeds) - if getattr(self.config, "tie_word_embeddings", True): - out_embeds.weight = input_embed_weights - if self.get_input_embeddings().num_additional_embeddings > 0: - assert self.get_output_embeddings().out_additional_features == self.get_input_embeddings().num_additional_embeddings - self.get_output_embeddings().additional_fc.weight = self.get_input_embeddings().additional_embedding.weight - else: - out_embeds.weight = self.get_output_embeddings().weight - self.initialized_flamingo = True self._use_cached_vision_x = False From 8ff8d5edd1f39f8cbd8109ad653b2f0b4f7ca6ac Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Tue, 5 Sep 2023 12:57:11 -0700 Subject: [PATCH 23/24] Update factory.py --- open_flamingo/src/factory.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 987da0c6..d9f86ec9 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -77,6 +77,9 @@ def create_model_and_transforms( trust_remote_code=True ) + # change model's vocab size to include new tokens + lang_encoder.config.vocab_size = len(text_tokenizer) + # hacks for MPT-1B, which doesn't have a get_input_embeddings method if "mpt-1b-redpajama-200b" in lang_encoder_path: From d75fecd118eaefbaf13806bd69aea12c11df8f6b Mon Sep 17 00:00:00 2001 From: Anas Awadalla Date: Fri, 8 Sep 2023 17:35:55 -0700 Subject: [PATCH 24/24] fix embed init and out embed concat --- open_flamingo/src/factory.py | 34 +++++++++++++---------------- open_flamingo/src/flamingo.py | 37 ++++++++++++++++++++++++-------- open_flamingo/src/flamingo_lm.py | 20 +++++++++++------ open_flamingo/src/helpers.py | 11 ++++++---- 4 files changed, 63 insertions(+), 39 deletions(-) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 987da0c6..c8cb8fae 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -52,30 +52,35 @@ def create_model_and_transforms( tokenizer_path, local_files_only=use_local_files, cache_dir=cache_dir, - trust_remote_code=True + # trust_remote_code=True ) + # add Flamingo special tokens to the tokenizer text_tokenizer.add_special_tokens( {"additional_special_tokens": ["<|endofchunk|>", ""]} ) new_tokens = 2 - if text_tokenizer.pad_token is None: + if text_tokenizer.pad_token is None and text_tokenizer.pad_token_id is None: # need to check both because some tokenizers have a pad token id but not a pad token # Issue: GPT models don't have a pad token, which we use to # modify labels for the loss. - text_tokenizer.add_special_tokens({"pad_token": ""}) - new_tokens += 1 + text_tokenizer.pad_token_id = text_tokenizer.eos_token_id + + # text_tokenizer.add_special_tokens({"pad_token": ""}) + # new_tokens += 1 ids_for_additional_special_tokens = text_tokenizer.convert_tokens_to_ids( - ["<|endofchunk|>", "", ""] if new_tokens == 3 else ["<|endofchunk|>", ""] + ["<|endofchunk|>","",""] if new_tokens == 3 else ["<|endofchunk|>", ""] ) - print(f"Added {new_tokens} new tokens to the tokenizer") lang_encoder = AutoModelForCausalLM.from_pretrained( lang_encoder_path, local_files_only=use_local_files, cache_dir=cache_dir, - trust_remote_code=True + # trust_remote_code=True ) + + lang_encoder.config.update({"original_vocab_size": min(ids_for_additional_special_tokens)}) + lang_encoder.config.vocab_size = max(len(text_tokenizer), lang_encoder.config.vocab_size) # hacks for MPT-1B, which doesn't have a get_input_embeddings method if "mpt-1b-redpajama-200b" in lang_encoder_path: @@ -108,15 +113,7 @@ def set_input_embeddings(self, new_embeddings): ) if untie_embeddings: - new_output_embeddings_weight = lang_encoder.get_output_embeddings().weight.clone() - if lang_encoder.get_output_embeddings().bias is not None: - new_output_embeddings_bias = lang_encoder.get_output_embeddings().bias.clone() - else: - new_output_embeddings_bias = None - lang_encoder.get_output_embeddings().weight = nn.Parameter(new_output_embeddings_weight) - if new_output_embeddings_bias is not None: - lang_encoder.get_output_embeddings().bias = nn.Parameter(new_output_embeddings_bias) - + lang_encoder.get_output_embeddings().weight = nn.Parameter(lang_encoder.get_output_embeddings().weight.clone()) lang_encoder.config.update({"tie_word_embeddings": False}) # convert LM to FlamingoLM @@ -125,7 +122,7 @@ def set_input_embeddings(self, new_embeddings): if decoder_layers_attr_name is None: decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder) lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name) - + model = Flamingo( vision_encoder, lang_encoder, @@ -135,9 +132,8 @@ def set_input_embeddings(self, new_embeddings): "width" ], cross_attn_every_n_layers=cross_attn_every_n_layers, - # HACK: The tokenizer's size and model's vocab size sometimes don't match. We use this to find the smaller of the flamingo special tokens and use that as the vocab size (even though the true one might be smaller). - vocab_size=min(ids_for_additional_special_tokens), new_tokens=new_tokens, # number of tokens embeddings to train + padding_token_id=text_tokenizer.pad_token_id, **flamingo_kwargs, ) diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py index 28b0304a..61f0b6da 100644 --- a/open_flamingo/src/flamingo.py +++ b/open_flamingo/src/flamingo.py @@ -21,8 +21,9 @@ def __init__( lang_encoder: nn.Module, eoc_token_id: int, media_token_id: int, + padding_token_id: int, vis_dim: int, - vocab_size: int, + # vocab_size: int, new_tokens: int, cross_attn_every_n_layers: int = 1, gradient_checkpointing: bool = False, @@ -33,9 +34,9 @@ def __init__( lang_encoder (nn.Module): HF causal language model eoc_token_id (int): Token id for <|endofchunk|> media_token_id (int): Token id for + padding_token_id (int): Token id for padding token vis_dim (int): Dimension of the visual features. Visual features are projected to match this shape along the last dimension. - vocab_size (int): Size of the base vocabulary. new_tokens (int): Number of new tokens added to the tokenizer. cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. @@ -54,11 +55,11 @@ def __init__( self.lang_encoder = lang_encoder self.lang_encoder.init_flamingo( media_token_id=media_token_id, + padding_token_id=padding_token_id, lang_hidden_size=self.lang_dim, vis_hidden_size=self.vis_dim, cross_attn_every_n_layers=cross_attn_every_n_layers, gradient_checkpointing=gradient_checkpointing, - vocab_size=vocab_size, new_tokens=new_tokens, ) self._use_gradient_checkpointing = gradient_checkpointing @@ -275,12 +276,30 @@ def wrap_fsdp(self, wrapper_kwargs, device_id): for layer in self.lang_encoder.gated_cross_attn_layers ) self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing) - self.lang_encoder.set_input_embeddings( - wrap(wrap(self.lang_encoder.get_input_embeddings())) - ) - self.lang_encoder.set_output_embeddings( - wrap(wrap(self.lang_encoder.get_output_embeddings())) - ) + if hasattr(self.lang_encoder.get_input_embeddings(), "additional_embedding"): + # wrap additional_embedding and original embedding separately, this is the case when using decoupled embeddings + self.lang_encoder.get_input_embeddings().additional_embedding = wrap( + wrap(self.lang_encoder.get_input_embeddings().additional_embedding) + ) + self.lang_encoder.get_input_embeddings().weight = wrap(wrap(self.lang_encoder.get_input_embeddings().weight)) + else: + self.lang_encoder.set_input_embeddings( + wrap(wrap(self.lang_encoder.get_input_embeddings())) + ) + + if hasattr(self.lang_encoder.get_output_embeddings(), "additional_fc"): + # wrap additional_fc and original embedding separately, this is the case when using decoupled linear layer + self.lang_encoder.get_output_embeddings().additional_fc = wrap( + wrap(self.lang_encoder.get_output_embeddings().additional_fc) + ) + self.lang_encoder.get_output_embeddings().weight = wrap(wrap(self.lang_encoder.get_output_embeddings().weight)) + if self.lang_encoder.get_output_embeddings().bias is not None: + self.lang_encoder.get_output_embeddings().bias = wrap(wrap(self.lang_encoder.get_output_embeddings().bias)) + else: + self.lang_encoder.set_output_embeddings( + wrap(wrap(self.lang_encoder.get_output_embeddings())) + ) + self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen # manually move non-FSDP managed parameters to device_id diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index 16cb04bf..6c9545b3 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -87,11 +87,11 @@ def _set_decoder_layers(self, value): def init_flamingo( self, media_token_id, + padding_token_id, lang_hidden_size, vis_hidden_size, cross_attn_every_n_layers, gradient_checkpointing, - vocab_size, new_tokens, ): """ @@ -113,17 +113,19 @@ def init_flamingo( # modify the embedding layer to support decoupling input_embeds = FlamingoDecoupledEmbedding( - num_embeddings=vocab_size, + num_embeddings=self.config.original_vocab_size, num_additional_embeddings=new_tokens, - embedding_dim=self.get_input_embeddings().weight.shape[1], + embedding_dim=self.config.hidden_size, partially_freeze=True, + padding_idx=padding_token_id, ) input_embeds.weight = self.get_input_embeddings().weight + input_embeds.additional_embedding.weight.data.normal_(mean=0.0, std=self.config.initializer_range) self.set_input_embeddings(input_embeds) out_embeds = FlamingoDecoupledLinear( - in_features=self.get_input_embeddings().weight.shape[1], - out_features=vocab_size, + in_features=self.config.hidden_size, + out_features=self.config.original_vocab_size, bias=self.get_output_embeddings().bias is not None, out_additional_features=new_tokens, partially_freeze=True, @@ -132,9 +134,13 @@ def init_flamingo( if self.get_output_embeddings().bias is not None: out_embeds.bias = self.get_output_embeddings().bias - out_embeds.weight = self.get_output_embeddings().weight + out_embeds.weight = self.get_output_embeddings().weight + out_embeds.additional_fc.weight.data.normal_(mean=0.0, std=self.config.initializer_range) self.set_output_embeddings(out_embeds) - + + if getattr(self.config, "tie_word_embeddings", False): + self.get_output_embeddings().additional_fc.weight = self.get_input_embeddings().additional_embedding.weight + self.initialized_flamingo = True self._use_cached_vision_x = False diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py index 43b0f429..dd44737b 100644 --- a/open_flamingo/src/helpers.py +++ b/open_flamingo/src/helpers.py @@ -281,8 +281,6 @@ def forward( # Both FlamingoDecoupledEmbedding and FlamingoDecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity - - class FlamingoDecoupledEmbedding(nn.Embedding): # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding """ @@ -382,7 +380,7 @@ def forward(self, input_ids): # for successful lookup replace input_ids with 0, the results of these will be discarded anyway input_ids[additional_vocab_indices] = 0 full_vector = F.embedding(input_ids, self.weight) - + # overwrite the records with high indices full_vector[additional_vocab_indices] = additional_embeddings @@ -449,7 +447,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: additional_features = F.linear( input, self.additional_fc.weight, self.additional_fc.bias ) - output = torch.cat((output, additional_features), -1) + # Concatenate the additional features to the output if new vocab doesn't have a placeholder token in the original embedding + if self.weight.shape[0] < self.out_features + self.out_additional_features: + output = torch.cat((output, additional_features), dim=-1) + else: + # Otherwise, overwrite the placeholder tokens with the additional features + output[..., self.out_features:self.out_features + self.out_additional_features] = additional_features return output