
[Question] No gain in VRAM usage with LigerFusedLinearCrossEntropyLoss #941

@mrinaldi97


Hello,
I am writing a codebase to train transformer models (it's almost finished, but still too early to share the entire framework), and I have just added Liger-Kernel support.
I was expecting lower VRAM usage with the Fused Linear Cross Entropy implementation compared to the naive torch F.cross_entropy, however:

Torch => 22785MiB
Liger Fused => 22155MiB

Only about 600 MiB saved. The vocab size is 32777 (maybe too small to see a gain? Or maybe because it is not a power of 16?)
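
For a sense of scale, here is a rough estimate of the logits tensor that the fused kernel is supposed to avoid materializing (the batch size and sequence length below are illustrative assumptions, not necessarily my real settings):

    # Back-of-the-envelope size of the full logits tensor in bf16.
    # batch_size and seq_len are illustrative assumptions.
    batch_size, seq_len, vocab_size = 8, 1024, 32777
    bytes_per_element = 2  # bf16
    logits_bytes = batch_size * seq_len * vocab_size * bytes_per_element
    print(f"logits tensor: {logits_bytes / 2**20:.0f} MiB")  # ~512 MiB

With a ~33k vocabulary the avoided tensor is only a few hundred MiB per micro-batch, so maybe savings of this order are expected and the kernel mainly pays off with much larger vocabularies?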

It's training a small test autoregressive model:

    "hidden_size": 768,
    "ffn_factor": 3.0,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,

Testing on an NVIDIA RTX 3090 GPU with PyTorch 2.6 and CUDA 12.4; training is done with AMP in PyTorch Lightning at bf16 precision.
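
The memory numbers above are overall GPU usage; for a cleaner comparison I could also look at PyTorch's own peak-allocation counter around a single training step, something like:

    import torch

    # Measure the allocator's peak usage for one step, excluding the CUDA
    # context and caching overhead that inflates the process-level number.
    torch.cuda.reset_peak_memory_stats()
    # ... run one training step here ...
    print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")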

Here is the transformer block with lm_head:

    class TransformerWithLMHead(nn.Module):
        """
        Adds an LM head to TransformerWithEmbeddingHead. This is enough for BERT-like/GPT-like models.
        """
        def __init__(self, config: ModelConfig, cache=None):
            super().__init__()
            self.cache = ensure_cache_and_registry(cache)
            cache = self.cache
            self.lm_head = ModuleWrapper(self.cache.registry.create(
                "linear", "linear",
                in_features=config.hidden_size,
                out_features=config.vocab_size,
            ))
            self.transformer = TransformerWithEmbeddingHead(config, cache=cache)
            if config.tie_word_embeddings:
                self.lm_head.weight = self.transformer.embed_tokens.weight
            self.config = config

        def forward(self, x, return_type='logits', **kwargs):
            x = self.transformer(x, **kwargs)
            if return_type == 'logits':
                return self.lm_head(x)
            else:
                return x
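
To illustrate the two return modes (just a minimal sketch, not the actual training code):

    import torch

    # Sketch: 'logits' returns the projected vocabulary scores, 'hidden' returns
    # the last hidden states so the fused loss can do the projection itself.
    model = TransformerWithLMHead(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 128))

    logits = model(input_ids)                         # [batch, seq, vocab_size]
    hidden = model(input_ids, return_type='hidden')   # [batch, seq, hidden_size]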

Here is the relevant snippet from the training step:

    if self.loss_type == 'fused':
        model_return_type = 'hidden'
        flattening_dimension = self.config.hidden_size
        loss_kwargs = {"lm_head_weight": self.model.lm_head.module.inner.weight}
        if hasattr(self.model.lm_head, "bias"):
            loss_kwargs["lm_head_bias"] = self.model.lm_head.module.inner.bias  # TODO: better way to access inner attributes of wrapped modules

And finally, here is how the Liger kernel is used:

    @registry.register("loss", "cross_entropy_loss_fused", "liger", requires=["liger_kernel"], priority=0)
    class LigerCrossEntropyLossFused(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__()
            cls = _load("liger_kernel.transformers", "LigerFusedLinearCrossEntropyLoss")
            self.inner = cls(*args, **kwargs)

        def forward(self, hidden, targets, **kwargs):
            return self.inner(
                _input=hidden,
                target=targets,
                lin_weight=kwargs['lm_head_weight'],
                bias=kwargs.get("lm_head_bias", None),
            )
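
To see whether the wrapper itself is numerically equivalent to the unfused path, I can compare the two losses on the same random inputs in isolation (a quick standalone sketch; keyword names taken from the call above):

    import torch
    import torch.nn.functional as F
    from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

    # Quick check: fused loss vs. explicit logits + F.cross_entropy on random data.
    torch.manual_seed(0)
    hidden_size, vocab_size, n_tokens = 768, 32777, 512
    hidden = torch.randn(n_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16) * 0.02
    targets = torch.randint(0, vocab_size, (n_tokens,), device="cuda")

    fused = LigerFusedLinearCrossEntropyLoss()(lin_weight=weight, _input=hidden, target=targets)
    naive = F.cross_entropy((hidden @ weight.t()).float(), targets)  # fp32 logits as a reference
    print(fused.item(), naive.item())  # these should be close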

Moreover, the loss diverges compared to torch:

[Plots: training loss with Liger vs. training loss with Torch]

Is it implemented correctly?
Thank you
