We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4984f7b · commit 5170e9d (Copy full SHA for 5170e9d)
src/transformers/generation/utils.py
@@ -656,6 +656,7 @@ def prepare_inputs_for_generation(
656
if causal_mask_creation_function is None: # can't be found
657
output_attentions = kwargs.get("output_attentions", False)
658
token_type_ids = getattr(model_input, "token_type_ids", None)
659
+ # Some models may overwrite the general one
660
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
661
attention_mask = causal_mask_creation_function(
662
config=self.config,
0 commit comments