Fix to prefix tuning to fit transformers #2096

Merged
Changes from all commits
Commits (31)
7c8e287
[WIP] Fix to prefix tuning to fit transformers
BenjaminBossan Sep 25, 2024
b666532
Update src/peft/peft_model.py
BenjaminBossan Oct 22, 2024
73496ee
FIX: Change check if past_key_values is empty (#2106)
BenjaminBossan Sep 27, 2024
d60d1b6
DOC Update source install instruction (#2110)
Salehbigdeli Sep 30, 2024
faa4dd8
FIX Refactor OFT, small changes to BOFT (#1996)
Zeju1997 Oct 1, 2024
5cd5a45
ENH: Improved attribute access for modules_to_save (#2117)
BenjaminBossan Oct 2, 2024
0312b30
FIX low_cpu_mem_usage consolidates devices (#2113)
BenjaminBossan Oct 2, 2024
4c50892
TST Mark flaky X-LoRA test as xfail (#2114)
BenjaminBossan Oct 2, 2024
8699ba4
ENH: Warn when from_pretrained misses PEFT keys (#2118)
BenjaminBossan Oct 2, 2024
9ddc9f1
FEAT: Adding exclude modules param(#2044) (#2102)
JINO-ROHIT Oct 3, 2024
5a560da
FIX BC breaking change to boft conv2d scaling variable (#2127)
Zeju1997 Oct 7, 2024
d10151e
FEAT: VeRA quantization using bitsandbytes (#2070) (#2076)
ZiadHelal Oct 7, 2024
1d55d8b
Bump version to 0.13.2.dev0 (#2137)
BenjaminBossan Oct 8, 2024
98cf284
FEAT: Support torchao (#2062)
BenjaminBossan Oct 8, 2024
7961e8c
FIX: PiSSA now works with Conv1D layers (#2103) (#2104)
suyang160 Oct 8, 2024
fe8ba8e
FIX Type annoations in vera/bnb.py (#2139)
BenjaminBossan Oct 9, 2024
171cc75
ENH Make PEFT configs forward compatible (#2038)
BenjaminBossan Oct 9, 2024
858e1d2
FIX Raise mixed adapter infer with missing adapter (#2090)
BenjaminBossan Oct 9, 2024
b494d0e
FIX Prompt learning with latest transformers error (#2140)
BenjaminBossan Oct 9, 2024
f2d40e7
ENH LoRA notebook for NER task (#2126)
JINO-ROHIT Oct 10, 2024
7e5519a
FIX TST NaN issue with HQQ GPU test (#2143)
BenjaminBossan Oct 10, 2024
d0c22b3
FIX Bug in target module optimization if suffix (#2144)
BenjaminBossan Oct 10, 2024
3d205bc
Bump version to 0.13.2.dev0 (#2145)
BenjaminBossan Oct 11, 2024
7dfd956
FIX Don't assume past_key_valus for encoder models (#2149)
BenjaminBossan Oct 14, 2024
e74a6b9
FIX Use `SFTConfig` instead of `SFTTrainer` keyword args (#2150)
qgallouedec Oct 15, 2024
f481c5d
make style
BenjaminBossan Oct 22, 2024
9b223ea
Merge branch 'main' into fix-prefix-tuning-dynamic-cache
BenjaminBossan Oct 22, 2024
2d6f2fb
Some further fixes
BenjaminBossan Oct 23, 2024
f20652b
Fixes for seq2seq model
BenjaminBossan Oct 23, 2024
79250cc
Mistral test requires protobuf (???)
BenjaminBossan Oct 23, 2024
bb6131a
Also sentencepiece for mistral -_-
BenjaminBossan Oct 23, 2024
2 changes: 2 additions & 0 deletions setup.py
@@ -36,6 +36,8 @@
     "datasets",
     "diffusers",
     "scipy",
+    "protobuf",
+    "sentencepiece",
 ]

 setup(
32 changes: 27 additions & 5 deletions src/peft/peft_model.py
Expand Up @@ -34,7 +34,7 @@
from safetensors import safe_open
from safetensors.torch import save_file as safe_save_file
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import PreTrainedModel
from transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
from transformers.utils import PushToHubMixin

@@ -730,6 +730,18 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
             if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
                 post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
                 past_key_values = post_process_fn(past_key_values)
+            elif peft_config.num_transformer_submodules == 1:
+                # Don't apply this to encoder-decoder models or to models requiring special processing.
+                # local import in case users use a very old transformers version
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            elif peft_config.num_transformer_submodules == 2 and self.base_model._supports_cache_class:
+                # Don't apply this to encoder-decoder models that don't support the new Cache format yet.
+                # If we don't apply this, prefix tuning fails to update the cross-attention cache.
+                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+                past_key_values.cross_attention_cache = DynamicCache()
+                past_key_values.is_updated = {
+                    layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
+                }
             return past_key_values
         else:
             if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
@@ -2066,10 +2078,20 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
         if peft_config.peft_type == PeftType.POLY:
             model_kwargs["task_ids"] = task_ids
-        if model_kwargs.get("past_key_values", None) is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
-            batch_size = model_kwargs["decoder_input_ids"].shape[0]
-            past_key_values = self.get_prompt(batch_size)
-            model_kwargs["past_key_values"] = past_key_values
+        elif peft_config.peft_type == PeftType.PREFIX_TUNING:
+            past_key_values = model_kwargs.get("past_key_values", None)
+            cache_position = model_kwargs.get("cache_position", [None])
+            # check prefill stage
+            is_prefill_stage = (
+                # old cache implementation
+                (past_key_values is None)
+                # new cache implementation
+                or (isinstance(past_key_values, Cache) and (cache_position[0] == 0))
+            )
+            if is_prefill_stage:
+                batch_size = model_kwargs["decoder_input_ids"].shape[0]
+                new_past_key_values = self.get_prompt(batch_size)
+                model_kwargs["past_key_values"] = new_past_key_values

         return model_kwargs

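For readers of this diff, here is a minimal standalone sketch (not code from the PR; shapes are made up for illustration) of the mechanism the two hunks above rely on: the per-layer (key, value) prefix produced by prefix tuning, previously passed around as a legacy tuple, is wrapped in a transformers DynamicCache (or EncoderDecoderCache for seq2seq models), and the prefill stage is detected via cache_position so the prefix is injected exactly once. It assumes a transformers version that exports the Cache classes.

# Illustrative sketch only -- not code from this PR.
import torch
from transformers import Cache, DynamicCache

num_layers, batch_size, num_heads, num_virtual_tokens, head_dim = 2, 1, 4, 10, 8

# prefix tuning yields one (key, value) pair per layer in the legacy tuple format
legacy_cache = tuple(
    (
        torch.zeros(batch_size, num_heads, num_virtual_tokens, head_dim),
        torch.zeros(batch_size, num_heads, num_virtual_tokens, head_dim),
    )
    for _ in range(num_layers)
)

# recent transformers releases expect a Cache object instead of the legacy tuples
past_key_values = DynamicCache.from_legacy_cache(legacy_cache)
assert isinstance(past_key_values, Cache)
assert past_key_values.get_seq_length() == num_virtual_tokens  # the prefix counts as already-cached tokens

# prefill detection as in prepare_inputs_for_generation above: on the first
# generation step cache_position starts at 0, so the prefix is injected only then
cache_position = torch.arange(0, 3)
is_prefill_stage = isinstance(past_key_values, Cache) and (cache_position[0] == 0)
assert bool(is_prefill_stage)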
41 changes: 40 additions & 1 deletion tests/test_decoder_models.py
@@ -11,13 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import tempfile
 import unittest
 from unittest.mock import Mock, call, patch

 import pytest
 import torch
+from datasets import load_dataset
 from parameterized import parameterized
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForLanguageModeling,
+    Trainer,
+    TrainingArguments,
+)

 from peft import (
     AdaLoraConfig,
@@ -466,3 +474,34 @@ def test_prompt_learning_with_grouped_query_attention(self):
         x = torch.tensor([[1, 2, 3]])
         # does not raise
         model(x)
+
+    def test_prefix_tuning_mistral(self):
+        # See issues 869 and 1962
+        model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
+        base_model = AutoModelForCausalLM.from_pretrained(model_id)
+        peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
+        model = get_peft_model(base_model, peft_config)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        def process(samples):
+            tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
+            return tokenized
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(process, batched=True)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    num_train_epochs=1,
+                    max_steps=5,
+                    per_device_train_batch_size=4,
+                    output_dir=tmp_dirname,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            trainer.train()
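As a complement to the training-based test above, a quick end-to-end check of the fix is a short generation run with the same tiny Mistral model (its tokenizer is presumably why protobuf and sentencepiece were added to the test dependencies, per the "Mistral test requires protobuf" and "Also sentencepiece for mistral" commits). This is a hypothetical snippet, not part of the PR's test suite.

# Hypothetical quick check -- not part of the PR's tests.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
model = get_peft_model(base_model, peft_config)

inputs = tokenizer("Quote of the day:", return_tensors="pt")
with torch.no_grad():
    # before this fix, prefix tuning could fail here on transformers versions that
    # pass Cache objects instead of legacy past_key_values tuples
    output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))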
4 changes: 2 additions & 2 deletions tests/testing_common.py
@@ -1601,8 +1601,8 @@ def get_output(model):

         output_peft = get_output(peft_model)

-        # first check trivial case is not true that peft does not affect the output; for this to work, init_lora_weight
-        # must be False
+        # first check that we are not in the trivial case where PEFT does not affect the output; for this to work,
+        # init_weight must be False (if the config supports it)
         if isinstance(peft_model, StableDiffusionPipeline):
             # for SD, check that most pixels have different values
             assert (output_before != output_peft).float().mean() > 0.8