Commit 29aedec

FIX Account for attention mask being a dict
Resolves CI errors such as this one: https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

After resolving that error, other errors can occur, but they're unrelated and are investigated independently.

After the transformers change in huggingface/transformers#37866, it can happen that:

> Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of attention_mask being a dict. Right now, I simply extract the single value if the dict has just one element. For other sizes, I raise an error, as I don't know how to deal with that case. For our tests this is enough, but we might need to find a better solution in the future.
1 parent cc38f09 commit 29aedec
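
For illustration only (not part of the commit): a minimal sketch of the situation described above. The helper name unwrap_attention_mask and the dict key "full_attention" are hypothetical; the sketch only mirrors the logic added in the diff below, where a one-element dict is unwrapped to its single tensor and anything larger is rejected.

import torch

def unwrap_attention_mask(attention_mask):
    # Transformers may now return one mask per attention type (e.g. for gemma3);
    # PEFT's prompt learning code needs a single tensor to work with.
    if isinstance(attention_mask, dict):
        if len(attention_mask) != 1:
            raise ValueError(f"Expected a single attention mask, got {len(attention_mask)} instead")
        attention_mask = list(attention_mask.values())[0]
    return attention_mask

mask = torch.ones(2, 8, dtype=torch.long)
assert torch.equal(unwrap_attention_mask({"full_attention": mask}), mask)  # one-element dict is unwrapped
assert torch.equal(unwrap_attention_mask(mask), mask)  # plain tensors pass through unchanged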

File tree

1 file changed: +10 -0 lines changed


src/peft/peft_model.py

Lines changed: 10 additions & 0 deletions
@@ -1942,6 +1942,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                 model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]

             if (attention_mask := model_kwargs.get("attention_mask", None)) is not None:
+                if isinstance(attention_mask, dict):
+                    # see: https://github.com/huggingface/transformers/pull/37866
+                    # For now, just deal with the case of a single attention mask
+                    if len(attention_mask) != 1:
+                        raise ValueError(
+                            f"Expected a single attention mask, got {len(attention_mask)} instead, please open an "
+                            "issue (https://github.com/huggingface/peft/issues) and report the error."
+                        )
+                    attention_mask = list(attention_mask.values())[0]
+
                 size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                 prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
                 if attention_mask.dim() == 4:
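
For context, a rough sketch of why a plain tensor is needed at this point: as the context lines above suggest, PEFT builds a mask of ones covering the virtual tokens and prepends it to the regular attention mask. The torch.cat call below is an approximation of the 2D case for illustration, not a verbatim quote of the method.

import torch

batch_size, seq_len, num_virtual_tokens = 2, 8, 10
attention_mask = torch.ones(batch_size, seq_len)  # regular 2D mask, one row per sequence

# Mirrors the context lines above: ones for the virtual (prompt) tokens ...
size = batch_size, num_virtual_tokens
prefix_attention_mask = torch.ones(size).to(attention_mask.device)

# ... prepended to the regular mask so the virtual tokens are attended to (illustrative 2D case).
combined = torch.cat((prefix_attention_mask, attention_mask), dim=1)
print(combined.shape)  # torch.Size([2, 18])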
