Closed
Labels: ⚡ PEFT (Related to PEFT) · ⚡ accelerate (Related to accelerate) · 🐛 bug (Something isn't working)
Description
CI fails with dev dependencies: https://github.com/huggingface/trl/actions/runs/18400105862/job/52427161623
AttributeError: type object 'DynamicCache' has no attribute 'from_legacy_cache'
FAILED tests/test_sft_trainer.py::TestSFTTrainer::test_train_with_peft_config_prompt_tuning_1_prefix_tuning - AttributeError: type object 'DynamicCache' has no attribute 'from_legacy_cache'
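This looks like transformers on main having dropped the long-deprecated DynamicCache.from_legacy_cache classmethod, while the released peft still calls it in the prefix-tuning path. A quick check of the installed environment (a sketch, assuming the job installs transformers from source as its dev dependency):

import transformers
from transformers import DynamicCache

print(transformers.__version__)
# True on released transformers (where the method is deprecated),
# False on a dev build that has removed it -- which then surfaces in peft
# as the AttributeError below.
print(hasattr(DynamicCache, "from_legacy_cache"))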
Stack trace:
tests/test_sft_trainer.py:547: in test_train_with_peft_config_prompt_tuning
trainer.train()
.venv/lib/python3.12/site-packages/transformers/trainer.py:2152: in train
return inner_training_loop(
.venv/lib/python3.12/site-packages/transformers/trainer.py:2483: in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/sft_trainer.py:1185: in training_step
return super().training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/transformers/trainer.py:3762: in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/sft_trainer.py:1098: in compute_loss
(loss, outputs) = super().compute_loss(
.venv/lib/python3.12/site-packages/transformers/trainer.py:3829: in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/accelerate/utils/operations.py:819: in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/accelerate/utils/operations.py:807: in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44: in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/peft/peft_model.py:1890: in forward
kwargs["past_key_values"] = self.get_prompt(batch_size, max_cache_len=max_cache_len)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = PeftModelForCausalLM(
(base_model): Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(15166... (default): PrefixEncoder(
(embedding): Embedding(4, 16)
)
)
(word_embeddings): Embedding(151665, 8)
)
batch_size = 8, task_ids = None, max_cache_len = 19
def get_prompt(
    self, batch_size: int, task_ids: Optional[torch.Tensor] = None, max_cache_len: Optional[int] = None
) -> torch.Tensor:
    """
    Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.
    """
    peft_config = self.active_peft_config
    prompt_encoder = self.prompt_encoder[self.active_adapter]
    prompt_tokens = (
        self.prompt_tokens[self.active_adapter]
        .unsqueeze(0)
        .expand(batch_size, -1)
        .to(prompt_encoder.embedding.weight.device)
    )
    if peft_config.peft_type == PeftType.PREFIX_TUNING:
        prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
        if peft_config.inference_mode:
            past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
        else:
            past_key_values = prompt_encoder(prompt_tokens)
        if self.base_model_torch_dtype is not None:
            past_key_values = past_key_values.to(self.base_model_torch_dtype)
        past_key_values = past_key_values.view(
            batch_size,
            peft_config.num_virtual_tokens,
            peft_config.num_layers * 2,
            peft_config.num_attention_heads,
            peft_config.token_dim // peft_config.num_attention_heads,
        )
        if peft_config.num_transformer_submodules == 2:
            past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
        # Transpose: 2 x [num_layers, batch_size, num_heads, num_virtual_tokens, head_dim]
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
            peft_config.num_transformer_submodules * 2
        )
        base_model = self.get_base_model()
        model_config = getattr(base_model, "config", None)
        model_type = getattr(model_config, "model_type", "")
        if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
            post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
            past_key_values = post_process_fn(past_key_values)
        elif ("gemma2" in model_type) or ("gemma3_text" in model_type):
            # Gemma2 and Gemma3 only support HybridCache (which does not have the from_legacy_cache method)
            if max_cache_len is None:
                raise ValueError(
                    "max_cache_len is None but it should have been passed. Something went wrong, please open an "
                    "issue on GitHub with a reproducer: https://github.com/huggingface/peft/issues"
                )
            base_config = base_model.config
            if hasattr(base_config, "get_text_config"):
                base_config = base_config.get_text_config()
            new_cache = HybridCache(
                base_config,
                max_batch_size=batch_size,
                max_cache_len=max_cache_len,
                dtype=past_key_values[0].dtype,
                device=past_key_values[0].device,
            )
            cache_position = torch.arange(peft_config.num_virtual_tokens, device=past_key_values[0].device)
            for layer_idx in range(peft_config.num_layers):
                key_states, value_states = past_key_values[0][layer_idx], past_key_values[1][layer_idx]
                new_cache.update(
                    key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position}
                )
            past_key_values = new_cache
        elif peft_config.num_transformer_submodules == 1:
            # Don't apply this to encoder-decoder models or to models requiring special processing.
            # local import in case users use a very old transformers version
>           past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E           AttributeError: type object 'DynamicCache' has no attribute 'from_legacy_cache'

.venv/lib/python3.12/site-packages/peft/peft_model.py:778: AttributeError
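Until peft guards this call, the same conversion can be expressed with the public Cache API. The helper below is a hypothetical sketch (the name legacy_to_cache and the fallback logic are illustrative, not the upstream fix); it assumes DynamicCache() can still be constructed without arguments, that update() keeps its (key_states, value_states, layer_idx) signature, and that past_key_values is the per-layer tuple of stacked key/value tensors built in get_prompt above:

from transformers import DynamicCache

def legacy_to_cache(past_key_values):
    # Hypothetical shim mirroring the removed DynamicCache.from_legacy_cache:
    # each element of past_key_values stacks key and value along dim 0.
    if hasattr(DynamicCache, "from_legacy_cache"):  # older transformers releases
        return DynamicCache.from_legacy_cache(past_key_values)
    cache = DynamicCache()
    for layer_idx, layer in enumerate(past_key_values):
        key_states, value_states = layer[0], layer[1]
        cache.update(key_states, value_states, layer_idx)
    return cache

In the meantime the CI job could pin transformers to the latest release; the real fix presumably belongs in peft (or in trl's dev-dependency workflow).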