Description
Hello,
I am writing a codebase to train transformer models (almost finished, but it's still too early to share the entire framework) and I've just added Liger-Kernel support.
I was expecting a noticeable drop in GPU memory usage with the Fused Linear Cross Entropy implementation compared to plain torch F.cross_entropy, however:
Torch => 22785 MiB
Liger Fused => 22155 MiB
Only about 600 MiB saved. The vocab size is 32777 (maybe too small to see a gain? Maybe because it is not a power of 16?).
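If those MiB figures come from nvidia-smi, they also include the CUDA context and the caching allocator's unused blocks; here is a minimal sketch of how the allocator's own peak can be read instead (where exactly the calls go around a training step is up to the loop):

import torch

# Sketch: reset the peak counter before one training step, then read the
# allocator's high-water mark afterwards.
torch.cuda.reset_peak_memory_stats()
# ... run one forward/backward step here ...
peak_alloc = torch.cuda.max_memory_allocated() / 2**20     # MiB actually held by tensors
peak_reserved = torch.cuda.max_memory_reserved() / 2**20   # MiB reserved by the caching allocator
print(f"peak allocated: {peak_alloc:.0f} MiB, peak reserved: {peak_reserved:.0f} MiB")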
It's training a small test autoregressive model:
"hidden_size": 768,
"ffn_factor": 3.0,
"num_hidden_layers": 12,
"num_attention_heads": 12,
I'm testing on an NVIDIA 3090 GPU with torch 2.6 and CUDA 12.4; training is done in AMP with PyTorch Lightning, precision bf16.
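As a rough back-of-the-envelope for what the unfused path has to materialize (the batch size and sequence length below are placeholders, not the actual training values):

# Placeholder batch/seq values, just to size the full logits tensor in bf16.
batch, seq_len, vocab = 8, 1024, 32777
logits_mib = batch * seq_len * vocab * 2 / 2**20   # 2 bytes per bf16 element
print(f"full logits tensor: ~{logits_mib:.0f} MiB")  # ~512 MiB with these numbers

So with a vocab of this size the materialized logits are only a few hundred MiB per step, which is in the same ballpark as the ~600 MiB difference above.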
Here is the transformer block with lm_head:
class TransformerWithLMHead(nn.Module):
    """
    Adding an LM Head to TransformerWithEmbeddingHead. This is enough for Bert-like/GPT-like models.
    """
    def __init__(self, config: ModelConfig, cache=None):
        super().__init__()
        self.cache = ensure_cache_and_registry(cache)
        cache = self.cache
        self.lm_head = ModuleWrapper(self.cache.registry.create("linear", "linear", in_features=config.hidden_size, out_features=config.vocab_size))
        self.transformer = TransformerWithEmbeddingHead(config, cache=cache)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embed_tokens.weight
        self.config = config

    def forward(self, x, return_type='logits', **kwargs):
        x = self.transformer(x, **kwargs)
        if return_type == 'logits':
            return self.lm_head(x)
        else:
            return x
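For context, the same module serves both paths via return_type (a minimal usage sketch; input_ids is a placeholder name):

# Hypothetical usage of the class above; input_ids is a placeholder name.
logits = model(input_ids)                        # default return_type='logits' -> (B, T, vocab_size)
hidden = model(input_ids, return_type='hidden')  # pre-lm_head states -> (B, T, hidden_size)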
Here is the relevant snippet from the training step:
if self.loss_type == 'fused':
    model_return_type = 'hidden'
    flattening_dimension = self.config.hidden_size
    loss_kwargs = {"lm_head_weight": self.model.lm_head.module.inner.weight}
    if hasattr(self.model.lm_head, "bias"):
        loss_kwargs["lm_head_bias"] = self.model.lm_head.module.inner.bias  # TODO: better way to access inner attributes of wrapped modules
And finally, here is how the Liger kernel is used:
@registry.register("loss", "cross_entropy_loss_fused", "liger", requires=["liger_kernel"], priority=0)
class LigerCrossEntropyLossFused(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        cls = _load("liger_kernel.transformers", "LigerFusedLinearCrossEntropyLoss")
        self.inner = cls(*args, **kwargs)

    def forward(self, hidden, targets, **kwargs):
        return self.inner(_input=hidden, target=targets, lin_weight=kwargs['lm_head_weight'], bias=kwargs.get("lm_head_bias", None))

Moreover, the loss diverges compared to the torch implementation.
Is it implemented correctly?
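To isolate whether the call itself is the problem, a standalone comparison against the unfused torch path on random data is probably the quickest check (a sketch; the shapes are placeholders and the bias is omitted):

import torch
import torch.nn.functional as F
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Sketch: compare the fused loss against the unfused torch path on random data.
torch.manual_seed(0)
BT, H, V = 8 * 1024, 768, 32777   # (batch*seq, hidden_size, vocab_size) -- placeholders
hidden = torch.randn(BT, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = (torch.randn(V, H, device="cuda", dtype=torch.bfloat16) * 0.02).requires_grad_()
targets = torch.randint(0, V, (BT,), device="cuda")

fused_loss = LigerFusedLinearCrossEntropyLoss()(_input=hidden, target=targets, lin_weight=weight)
torch_loss = F.cross_entropy(hidden @ weight.T, targets)
print(fused_loss.item(), torch_loss.item())   # the two values should be close

If the two values already disagree here, the call itself is the problem; if they agree, the divergence more likely comes from how the kwargs and weights are wired up in the training loop.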
Thank you