Description
System Info
- Transformers 4.47.0.dev0 (latest commit 33868a0)
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
#33932 may break the logic behind the trainer's `model_accepts_loss_kwargs` check. The llama model no longer receives the `num_items_in_batch` argument, which makes the fix from #34283 ineffective again.
transformers/src/transformers/trainer.py, line 605 at 33868a0:

```python
self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
```
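A minimal sketch of why this check breaks, using simplified stand-in `forward` functions (the real signatures are in the snippets below): the trainer looks for a parameter literally named `loss_kwargs`, so a model whose variadic parameter is named `kwargs` is not detected.

```python
import inspect

def llama_forward(input_ids=None, labels=None, **kwargs):
    """Stand-in for the llama signature: variadic parameter named ``kwargs``."""

def gemma_forward(input_ids=None, labels=None, **loss_kwargs):
    """Stand-in for the gemma signature: variadic parameter named ``loss_kwargs``."""

def model_accepts_loss_kwargs(forward):
    # Same name-based test the trainer performs.
    return "loss_kwargs" in inspect.signature(forward).parameters

print(model_accepts_loss_kwargs(llama_forward))  # False: the parameter is named "kwargs"
print(model_accepts_loss_kwargs(gemma_forward))  # True: the parameter is named "loss_kwargs"
```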
Moreover, the variadic keyword argument is named differently in llama (`**kwargs`) than in other models such as gemma (`**loss_kwargs`); we would expect all models to use the same keyword arguments.
transformers/src/transformers/models/llama/modeling_llama.py, lines 1146 to 1161 at 33868a0:

```python
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
```
transformers/src/transformers/models/gemma/modeling_gemma.py, lines 1015 to 1030 at 33868a0:

```python
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
```
Expected behavior
The models' forward functions should have a consistent keyword argument list.
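Alternatively, until the signatures are unified, a sketch of a more robust detection (a hypothetical `accepts_var_keyword` helper, not the library's API): checking the parameter *kind* rather than its name would accept both the `**kwargs` and `**loss_kwargs` spellings.

```python
import inspect

def accepts_var_keyword(forward):
    # True if the signature has any **-style parameter, regardless of its name.
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD
        for p in inspect.signature(forward).parameters.values()
    )

# Stand-ins for the two signature styles shown above.
def llama_forward(input_ids=None, labels=None, **kwargs): ...
def gemma_forward(input_ids=None, labels=None, **loss_kwargs): ...

print(accepts_var_keyword(llama_forward))  # True
print(accepts_var_keyword(gemma_forward))  # True
```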