There has been significant refactoring of the loss functions in transformers 4.46 that will render the cross-entropy patching ineffective. We need a different `ModelPatcherRule` for the new transformers version. CC: @anhuong

huggingface/transformers#34191
So now there are 3 possibilities:

1. `custom_loss_function` is passed into `Trainer` (see the sketch below)
2. the model has migrated to the `custom_loss_function` API
3. the model has not migrated (like Granite now)
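For reference, a minimal sketch of possibility 1, assuming the user-supplied loss hook is the `compute_loss_func` argument accepted by recent transformers `Trainer` releases (the exact keyword name and call signature may differ across versions; `model` and `train_dataset` are placeholders):

```python
import torch
from transformers import Trainer, TrainingArguments

# hypothetical user-supplied loss; recent Trainer versions call it with the
# model outputs, the labels, and the number of items in the batch
def my_custom_loss(outputs, labels, num_items_in_batch=None):
    logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

trainer = Trainer(
    model=model,                       # assumed to be defined elsewhere
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_dataset,       # assumed to be defined elsewhere
    compute_loss_func=my_custom_loss,  # user-controlled loss, so we would not patch it
)
```

Because the user fully controls this function, there is nothing stable for us to patch in this case.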
For 3: this is the easy one, because it means no code changes.

For 1: I'm thinking we do not patch anything, because if a user wants to do this, we can't control what loss function they use.
For 2: in this case we want to patch `fixed_cross_entropy`, but this should be done on a per-model basis. So we need to somehow have the model instantiate the loss function, e.g., `ForCausalLMLoss`, and only patch `fixed_cross_entropy` during this instantiation process, then put it back to the original after it is done (see the sketch below).
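A minimal sketch of the temporary patch for possibility 2, assuming the refactored loss helpers live in `transformers.loss.loss_utils` (the module path is an assumption based on transformers >= 4.46, and `fused_cross_entropy` in the usage comment is a hypothetical replacement, not the final implementation):

```python
import contextlib

# assumption: the refactored loss helpers (fixed_cross_entropy, ForCausalLMLoss)
# live here in transformers >= 4.46; adjust the import if the path differs
from transformers.loss import loss_utils


@contextlib.contextmanager
def swap_fixed_cross_entropy(replacement):
    """Temporarily replace fixed_cross_entropy and restore the original on exit."""
    original = loss_utils.fixed_cross_entropy
    loss_utils.fixed_cross_entropy = replacement
    try:
        yield
    finally:
        # put the original back once the model is done instantiating its loss function
        loss_utils.fixed_cross_entropy = original


# hypothetical usage inside a per-model ModelPatcherRule:
# with swap_fixed_cross_entropy(fused_cross_entropy):
#     ...have the model instantiate / invoke its loss function here...
```

The per-model `ModelPatcherRule` would then only activate this swap for models that have migrated to the new loss API, leaving non-migrated models (case 3) on the existing patching path.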