FIX: Make adaptation prompt CI happy for transformers 4.39.0 (#1551)
* fix for transformers 4.39.0

* Update src/peft/tuners/adaption_prompt/utils.py

Co-authored-by: Benjamin Bossan <[email protected]>

---------

Co-authored-by: Benjamin Bossan <[email protected]>
younesbelkada and BenjaminBossan authored Mar 11, 2024
1 parent 234bbab commit a1fe368
Showing 1 changed file with 6 additions and 1 deletion: src/peft/tuners/adaption_prompt/utils.py
@@ -95,7 +95,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
     new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
     position_ids = new_cache_positions.unsqueeze(0)
 
-    cos, sin = model.rotary_emb(value_states, seq_len=q_len + past_seen_tokens, position_ids=position_ids)
+    rotary_emb_kwargs = {"position_ids": position_ids}
+    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
+    if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
+        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens
+
+    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)
 
     # For batched inference unsqueeze it on the correct dim
     # since: https://github.com/huggingface/transformers/pull/29109
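The guard in this patch is a general feature-detection pattern: inspect the callee's signature at runtime and pass an argument only if it is accepted, so one code path works on both sides of a breaking API change. Below is a minimal standalone sketch of the same idea; OldRotaryEmb, NewRotaryEmb, and compute_rotary are hypothetical stand-ins for illustration, not the real transformers classes.

import inspect

class OldRotaryEmb:
    # Stand-in for transformers < 4.39.0, where forward() still took `seq_len`.
    def forward(self, x, seq_len=None, position_ids=None):
        return "cos", "sin"

    __call__ = forward

class NewRotaryEmb:
    # Stand-in for transformers >= 4.39.0, where `seq_len` was removed.
    def forward(self, x, position_ids=None):
        return "cos", "sin"

    __call__ = forward

def compute_rotary(rotary_emb, value_states, position_ids, seq_len):
    # Build the kwargs every version accepts, then add `seq_len` only
    # when the installed version's forward() still declares it.
    kwargs = {"position_ids": position_ids}
    if "seq_len" in inspect.signature(rotary_emb.forward).parameters:
        kwargs["seq_len"] = seq_len
    return rotary_emb(value_states, **kwargs)

# The same call site works unchanged against both signatures:
cos, sin = compute_rotary(OldRotaryEmb(), None, position_ids=None, seq_len=8)
cos, sin = compute_rotary(NewRotaryEmb(), None, position_ids=None, seq_len=8)

Note that inspect.signature adds a small cost per call; in a hot path such as a per-token forward pass, the check could be computed once and cached rather than re-evaluated on every invocation.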
