From 0b1c9263d5ecd123de8ae3e89f611713452116cc Mon Sep 17 00:00:00 2001
From: root
Date: Sat, 24 Feb 2024 20:36:36 +0000
Subject: [PATCH] remove z-loss mess

---
 open_flamingo/train/losses.py | 83 +++--------------------------------
 1 file changed, 6 insertions(+), 77 deletions(-)

diff --git a/open_flamingo/train/losses.py b/open_flamingo/train/losses.py
index 0f86e3a4..f57727b4 100644
--- a/open_flamingo/train/losses.py
+++ b/open_flamingo/train/losses.py
@@ -1,16 +1,12 @@
 from open_flamingo.src.vlm import VLM
 import torch
-from torch import Tensor
-from torch.nn import CrossEntropyLoss
 
-SUPPORTED_LOSSES = ["next_token_prediction", "next_token_prediction_with_z_loss"]
+SUPPORTED_LOSSES = ["next_token_prediction"]
 
 
 def get_loss_fn(loss_name):
     if loss_name == "next_token_prediction":
         return NextTokenPrediction()
-    elif loss_name == "next_token_prediction_with_z_loss":
-        return NextTokenPredictionWithZLoss()
     else:
         raise ValueError(
             f"Loss {loss_name} not supported. Supported losses: {SUPPORTED_LOSSES}"
@@ -47,10 +43,10 @@ def __call__(
         raise NotImplementedError
 
 
-class NextTokenPredictionWithZLoss(Loss):
+class NextTokenPrediction(Loss):
     @property
     def name(self):
-        return "next_token_prediction_with_z_loss"
+        return "next_token_prediction"
 
     def __call__(
         self,
@@ -60,7 +56,6 @@ def __call__(
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         autocast: callable,
-        z_loss_eps: float = 1e-4,
     ):
         # set up labels; language model is expected to handle shifting
         labels = input_ids.clone()
@@ -74,55 +69,15 @@ def __call__(
 
         # call forward
         with autocast():
-            logits = model(
+            loss = model(
                 vision_x=images,
                 lang_x=input_ids,
                 attention_mask=attention_mask,
                 labels=labels,
-            )[1]
-
-        logits = logits.float()
-
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLossWithZLoss(eps=z_loss_eps)
-        shift_logits = shift_logits.view(-1, unwrap_model(model).lang_model.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-
+            )[0]
         return loss
 
 
-class NextTokenPrediction(NextTokenPredictionWithZLoss):
-    # same as NextTokenPredictionWithZLoss, but with z_loss_eps = 0
-    @property
-    def name(self):
-        return "next_token_prediction"
-
-    def __call__(
-        self,
-        model: VLM,
-        tokenizer,
-        images: torch.Tensor,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        autocast: callable,
-    ):
-        return super().__call__(
-            model=model,
-            tokenizer=tokenizer,
-            images=images,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            autocast=autocast,
-            z_loss_eps=0,
-        )
-
-
 def unwrap_model(model):
     """
     Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
@@ -132,30 +87,4 @@ def unwrap_model(model):
     ):
         return model.module
     else:
-        return model
-
-
-# From OpenLM (https://github.com/mlfoundations/open_lm/blob/main/open_lm/losses.py)
-class CrossEntropyLossWithZLoss(CrossEntropyLoss):
-    def __init__(
-        self,
-        eps: float = 1e-4,
-        weight: Tensor = None,
-        size_average=None,
-        ignore_index: int = -100,
-        reduce=None,
-        reduction: str = "mean",
-        label_smoothing: float = 0,
-    ) -> None:
-        super().__init__(
-            weight, size_average, ignore_index, reduce, reduction, label_smoothing
-        )
-        self.eps = eps
-
-    def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        if self.eps == 0:
-            return super().forward(input, target)
-
-        return super().forward(input, target) + self.eps * torch.square(
-            torch.logsumexp(input, dim=-1).mean()
-        )
+        return model
\ No newline at end of file
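
Note (not part of the patch): the deleted CrossEntropyLossWithZLoss, adapted
from OpenLM, augmented the standard cross-entropy with a z-loss term. It
returned

    CE + eps * (logsumexp(logits, dim=-1).mean()) ** 2

i.e. it penalized the squared mean log-partition of the logits, nudging the
softmax normalizer toward 1. A minimal standalone sketch of that removed
behavior follows; the function name and the shapes in the usage example are
illustrative, not from the repo:

    import torch
    import torch.nn.functional as F

    def cross_entropy_with_z_loss(logits, targets, eps=1e-4, ignore_index=-100):
        # logits: (num_tokens, vocab_size); targets: (num_tokens,)
        ce = F.cross_entropy(logits, targets, ignore_index=ignore_index)
        if eps == 0:
            return ce
        # log Z = logsumexp over the vocab dimension. As in the removed code,
        # the *mean* log-partition is squared, not the per-token squares.
        log_z = torch.logsumexp(logits, dim=-1)
        return ce + eps * torch.square(log_z.mean())

    # Usage with hypothetical sizes:
    logits = torch.randn(8, 100)
    targets = torch.randint(0, 100, (8,))
    loss = cross_entropy_with_z_loss(logits, targets)

With the z-loss path gone, NextTokenPrediction no longer shifts the logits and
computes the loss by hand: it passes labels into the forward call and reads the
scalar loss from index 0 of the model output instead of the logits at index 1,
so the language model's own (internally shifted) cross-entropy is used.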