From 92b9e1aacdd9d3c49af188dd97c2b6a8afcad11d Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Tue, 22 Oct 2024 10:57:11 +0200 Subject: [PATCH 01/14] added CPT model to peft --- src/peft/__init__.py | 2 + src/peft/mapping.py | 4 + src/peft/peft_model.py | 47 +++++ src/peft/tuners/__init__.py | 1 + src/peft/tuners/cpt/__init__.py | 20 ++ src/peft/tuners/cpt/config.py | 82 ++++++++ src/peft/tuners/cpt/model.py | 207 ++++++++++++++++++++ src/peft/utils/peft_types.py | 1 + tests/CPT_test.py | 332 ++++++++++++++++++++++++++++++++ 9 files changed, 696 insertions(+) create mode 100644 src/peft/tuners/cpt/__init__.py create mode 100644 src/peft/tuners/cpt/config.py create mode 100644 src/peft/tuners/cpt/model.py create mode 100644 tests/CPT_test.py diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 18f9cf4f43..b02dfa2264 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -91,6 +91,8 @@ HRAConfig, HRAModel, VBLoRAConfig, + CPTEmbedding, + CPTConfig, ) from .utils import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 5eb19dcb5b..f0acaaf394 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -38,6 +38,8 @@ AdaptionPromptConfig, BOFTConfig, BOFTModel, + CPTConfig, + CPTEmbedding, FourierFTConfig, FourierFTModel, HRAConfig, @@ -104,6 +106,7 @@ "XLORA": XLoraConfig, "HRA": HRAConfig, "VBLORA": VBLoRAConfig, + "CPT": CPTConfig, } PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[BaseTuner]] = { @@ -121,6 +124,7 @@ "XLORA": XLoraModel, "HRA": HRAModel, "VBLORA": VBLoRAModel, + "CPT": CPTEmbedding, } diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index b44fcfa601..aec0846563 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -45,6 +45,7 @@ AdaLoraModel, AdaptionPromptModel, BOFTModel, + CPTEmbedding, FourierFTModel, HRAModel, IA3Model, @@ -102,6 +103,7 @@ PeftType.XLORA: XLoraModel, PeftType.HRA: HRAModel, PeftType.VBLORA: VBLoRAModel, + PeftType.CPT: CPTEmbedding, } @@ -604,6 +606,8 @@ def _setup_prompt_encoder(self, adapter_name: str): prompt_encoder = PromptEncoder(config) elif config.peft_type == PeftType.PREFIX_TUNING: prompt_encoder = PrefixEncoder(config) + elif config.peft_type == PeftType.CPT: + prompt_encoder = CPTEmbedding(config, self.word_embeddings) else: raise ValueError("Not supported") @@ -1627,6 +1631,49 @@ def forward( # overwrite past_kv in kwargs kwargs["past_key_values"] = self.get_prompt(batch_size) return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs) + elif peft_config.peft_type == PeftType.CPT: + if peft_config.CPT_prompt_tuning_init == "TEXT": + CPT_token_ids = peft_config.CPT_token_ids + CPT_tokens_type_mask = peft_config.CPT_tokens_type_mask + else: + CPT_token_ids = [0] * peft_config.num_virtual_tokens + CPT_tokens_type_mask = [0] * peft_config.num_virtual_tokens + + # Extract input_type_mask from kwargs and move it to the same device as labels + input_type_mask = kwargs.pop("input_type_mask").to(labels.device) + # Generate embeddings if not provided + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + # Get prompt and concatenate with input embeddings + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + # If labels are provided, generate prefix labels and type mask + if labels is not None: + # Generate prefix labels and concatenate with the input labels + prefix_labels = torch.Tensor(CPT_token_ids).long().view(1, -1) + prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device) + CPT_labels = torch.cat((prefix_labels, labels), dim=1) + # Generate prefix type mask and shift input type mask values to avoid conflicts + prefix_type_mask = torch.Tensor(CPT_tokens_type_mask).long().view(1, -1) + prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device) + adjusted_input_type_mask = input_type_mask + adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max() + # Concatenate prefix and shifted input type masks + CPT_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1) + # Identify valid label positions and mask invalid ones with -100 + labels_idx = (CPT_type_mask > 0) & (CPT_type_mask % 4 == 0) + CPT_labels[~labels_idx] = -100 + # Update kwargs with the modified labels + kwargs["labels"] = CPT_labels + # Pass the modified inputs to the base model + base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs) + # Calculate the loss using the custom CPT loss function + base_model_output = CPTEmbedding.calculate_loss( + base_model_output, CPT_labels, CPT_type_mask, self.peft_config["default"] + ) + + return base_model_output else: if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index d58ff9e3e6..f463794b28 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -37,3 +37,4 @@ from .xlora import XLoraConfig, XLoraModel from .hra import HRAConfig, HRAModel from .vblora import VBLoRAConfig, VBLoRAModel +from .cpt import CPTConfig, CPTEmbedding \ No newline at end of file diff --git a/src/peft/tuners/cpt/__init__.py b/src/peft/tuners/cpt/__init__.py new file mode 100644 index 0000000000..67b200cce3 --- /dev/null +++ b/src/peft/tuners/cpt/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .config import CPTConfig +from .model import CPTEmbedding + + +__all__ = ["CPTConfig", "CPTEmbedding"] diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py new file mode 100644 index 0000000000..98be0c5597 --- /dev/null +++ b/src/peft/tuners/cpt/config.py @@ -0,0 +1,82 @@ +import enum +from dataclasses import dataclass, field +from typing import Optional + +import torch + +from peft.config import PeftConfig +from peft.utils import PeftType + + +class PromptTuningInit(str, enum.Enum): + """Enum for specifying the initialization method for prompt tuning.""" + + TEXT = "TEXT" # Initialize using text-based embeddings. + RANDOM = "RANDOM" # Initialize randomly. + + +@dataclass +class CPTConfig(PeftConfig): + """ + CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT). + + This class introduces additional parameters required for CPT, such as token type masks, + prompt tuning initialization, loss weighting, and projection settings. + """ + + # Token-related configurations + CPT_token_ids: Optional[torch.Tensor] = field( + default=None, metadata={"help": "Tensor of token IDs used for CPT prompts."} + ) + CPT_mask: Optional[torch.Tensor] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."}) + CPT_tokens_type_mask: Optional[bool] = field( + default=None, metadata={"help": "Mask indicating the type of each CPT token."} + ) + + # Prompt tuning initialization method + CPT_prompt_tuning_init: Optional[str] = field( + default="TEXT", metadata={"help": "Initialization method: 'TEXT' for embedding-based, 'RANDOM' for random."} + ) + + # Loss-related configurations + opt_weighted_loss_type: Optional[str] = field( + default="none", metadata={"help": "Type of weighted loss: 'none' or 'decay'."} + ) + opt_loss_decay_factor: Optional[float] = field( + default=1.0, metadata={"help": "Factor for exponential decay in loss weighting."} + ) + + # Projection-related configurations + opt_projection_epsilon: Optional[float] = field( + default=0.1, metadata={"help": "Epsilon value for input projection."} + ) + opt_projection_format_epsilon: Optional[float] = field( + default=0.1, metadata={"help": "Epsilon value for format projection."} + ) + + # Tokenizer configuration + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" + }, + ) + + # Virtual token configurations + num_virtual_tokens: int = field(default=0, metadata={"help": "Number of virtual tokens used in the prompt."}) + + # CPT-specific static attributes + is_prompt_learning = True # Indicates that CPT is a prompt-learning method. + num_layers = None # Number of layers (optional, not always required). + token_dim = None # Dimension of token embeddings. + num_attention_heads = None # Number of attention heads (if applicable). + task_type = "CAUSAL_LM" # Specifies that CPT is used for causal language modeling. + num_transformer_submodules = 1 # Number of transformer submodules used. + + def __post_init__(self): + """ + Post-initialization hook to set additional attributes after the config is initialized. + """ + self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. + self.target_modules = None # Placeholder for target modules in CPT. + self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py new file mode 100644 index 0000000000..dceb30590e --- /dev/null +++ b/src/peft/tuners/cpt/model.py @@ -0,0 +1,207 @@ +import copy + +import torch +from torch.nn import CrossEntropyLoss + +from peft.utils.integrations import gather_params_ctx + +from .config import PromptTuningInit + + +class CPTEmbedding(torch.nn.Module): + """ + CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. + It initializes embeddings, applies prompt-specific projections, and computes loss using label masks. + """ + + def __init__(self, config, word_embeddings): + """ + Initializes the CPTEmbedding module. + + Args: + config (Namespace): Configuration object containing model hyperparameters and CPT-specific settings. + word_embeddings (torch.nn.Embedding): The base word embedding layer used to initialize CPT embeddings. + """ + super().__init__() + self.config = copy.deepcopy(config) + self.check_config() + num_virtual_tokens = config.num_virtual_tokens + + # Initialize embeddings with virtual token dimensions + self.embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim) + + # Initialize embeddings using text-based prompt tuning, if configured + if config.CPT_prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode: + assert config.num_virtual_tokens == len(config.CPT_token_ids) + + init_token_ids = torch.LongTensor(config.CPT_token_ids).to(word_embeddings.weight.device) + with gather_params_ctx(word_embeddings.parameters()): + word_embedding_weights = word_embeddings(init_token_ids).detach().clone() + word_embedding_weights = word_embedding_weights.to(torch.float32) + self.embedding.weight = torch.nn.Parameter(word_embedding_weights) + + # Initialize delta embedding with zero weights + self.delta_embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim) + self.delta_embedding.weight.data = torch.zeros_like(self.delta_embedding.weight).to(torch.float32) + + # Apply hook for backward gradient updates + self.set_updated_tokens() + + def forward(self, indices): + """ + Computes the prompt embeddings and applies delta adjustments. + + Args: + indices (torch.Tensor): Indices of the tokens to be embedded. + + Returns: + torch.Tensor: Sum of prompt embeddings and delta embeddings. + """ + with torch.no_grad(): + prompt_embeddings = self.embedding(indices) + + self.projection() # Apply epsilon-based projection + delta_prompt_embeddings = self.delta_embedding(indices) + + return prompt_embeddings + delta_prompt_embeddings + + def set_updated_tokens(self): + """ + Sets up a backward hook to selectively update token gradients based on the CPT token type mask. + """ + if self.config.CPT_prompt_tuning_init == PromptTuningInit.TEXT: + tensor_ICL_mask = torch.Tensor(self.config.CPT_tokens_type_mask).long() + mask_input_template = torch.remainder(tensor_ICL_mask, 4) == 1 + mask_input = torch.remainder(tensor_ICL_mask, 4) == 2 + mask_output_template = torch.remainder(tensor_ICL_mask, 4) == 3 + mask = mask_input_template | mask_input | mask_output_template + mask = mask.view(-1, 1) + elif self.config.CPT_prompt_tuning_init == PromptTuningInit.RANDOM: + mask = torch.ones((self.config.num_virtual_tokens, 1)).long() + + def backward_hook(grad): + grad = grad * mask.to(grad.device) # Apply mask to gradients + return grad + + self.delta_embedding.weight.register_hook(backward_hook) + + def get_epsilon(self): + if self.config.CPT_prompt_tuning_init == "TEXT": + CPT_tokens_type_mask = self.config.CPT_tokens_type_mask + else: + CPT_tokens_type_mask = [2] * self.config.num_virtual_tokens + + MIN_VALUE = 1e-10 + + # Calculate normalized epsilon values for input, output, and format tokens + normalized_format_eps = ( + self.config.opt_projection_format_epsilon + * torch.sqrt( + torch.Tensor([self.config.token_dim / 2048]) + ) + ) + normalized_input_eps = self.config.opt_projection_epsilon * torch.sqrt( + torch.Tensor([self.config.token_dim / 2048]) + ) + + epsilon = torch.ones_like(torch.Tensor(CPT_tokens_type_mask)).to(torch.float32) * MIN_VALUE + CPT_tokens_type_mask = torch.Tensor(CPT_tokens_type_mask).long() + + epsilon[(CPT_tokens_type_mask > 0) & (torch.remainder(CPT_tokens_type_mask, 4) == 1)] = normalized_format_eps + epsilon[(CPT_tokens_type_mask > 0) & (torch.remainder(CPT_tokens_type_mask, 4) == 3)] = normalized_format_eps + epsilon[(CPT_tokens_type_mask > 0) & (torch.remainder(CPT_tokens_type_mask, 4) == 2)] = normalized_input_eps + + return epsilon + + def projection(self): + """ + Applies epsilon-based projection to the delta embeddings to control their norm. + """ + + # Apply projection to control delta embedding norm + with torch.no_grad(): + new_embeddings_weights = self.delta_embedding.weight.clone().to(self.delta_embedding.weight.device) + token_norm = torch.norm(new_embeddings_weights, p=2, dim=1) + + projection_mask = token_norm > 0 + if torch.any(projection_mask): + epsilon = self.get_epsilon().to(self.delta_embedding.weight.device) + new_embeddings_weights[projection_mask] *= ( + epsilon[projection_mask] / (token_norm[projection_mask].clamp(min=epsilon[projection_mask])) + ).view(-1, 1) + self.delta_embedding.weight.data = new_embeddings_weights + + @staticmethod + def calculate_loss(base_model_output, labels, CPT_type_mask, config): + """ + Computes the loss for CPT models with optional exponential decay. + + Args: + base_model_output (ModelOutput): Output from the base model containing logits. + labels (torch.Tensor): Ground-truth labels for the input tokens. + CPT_type_mask (torch.Tensor): Token type mask used for filtering valid loss terms. + config (Namespace): Configuration object containing loss-related hyperparameters. + + Returns: + ModelOutput: The base model output with computed loss. + """ + + if config.opt_weighted_loss_type in ["decay"]: + device = base_model_output.logits.device + + lm_logits = base_model_output.logits + labels = labels.to(device) + + # Shift logits and labels for token prediction + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_CPT_type_mask = CPT_type_mask[..., 1:].contiguous() + + shift_labels_bool = (shift_labels.clone().detach() != -100).bool() + batch_size, seq_length, vocab_size = shift_logits.shape + + # Compute cross-entropy loss + loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + loss = loss.view(batch_size, seq_length) + # Apply exponential decay weights to the loss + shift_labels_weights = shift_labels_bool.clone().detach().float() + for i in range(batch_size): + idx_labels = (shift_CPT_type_mask[i] > 0) & (shift_CPT_type_mask[i] % 4 == 0) + labels_ids = shift_CPT_type_mask[i][idx_labels].unique() + + exponential_decay = torch.ones_like(shift_CPT_type_mask[i]).to(device=device).float() + decay_value = 1 + for label_mask_idx in torch.flip(labels_ids, [0]): + exponential_decay[shift_CPT_type_mask[i] == label_mask_idx] = decay_value + decay_value *= config.opt_loss_decay_factor + shift_labels_weights[i] *= exponential_decay + + # Compute the weighted mean loss + loss = (loss[shift_labels_bool] * shift_labels_weights[shift_labels_bool]).mean() + base_model_output.loss = loss + elif config.opt_weighted_loss_type not in ["none"]: + raise NotImplementedError(f"Loss type '{config.opt_weighted_loss_type}' not implemented.") + + return base_model_output + + def check_config(self): + if self.config.CPT_prompt_tuning_init == PromptTuningInit.TEXT: + assert self.config.CPT_token_ids is not None + assert self.config.CPT_mask is not None + assert self.config.CPT_tokens_type_mask is not None + assert ( + len(self.config.CPT_token_ids) + == len(self.config.CPT_mask) + == len(self.config.CPT_tokens_type_mask) + == self.config.num_virtual_tokens + ) + elif self.config.CPT_prompt_tuning_init == PromptTuningInit.RANDOM: + assert self.config.CPT_token_ids is None + assert self.config.CPT_mask is None + assert self.config.CPT_tokens_type_mask is None + assert self.config.num_virtual_tokens > 0 + else: + raise NotImplementedError(f" was not implemented for {self.config.CPT_prompt_tuning_init}") diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 4072878700..02022439f5 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -63,6 +63,7 @@ class PeftType(str, enum.Enum): XLORA = "XLORA" HRA = "HRA" VBLORA = "VBLORA" + CPT = "CPT" class TaskType(str, enum.Enum): diff --git a/tests/CPT_test.py b/tests/CPT_test.py new file mode 100644 index 0000000000..18862f867c --- /dev/null +++ b/tests/CPT_test.py @@ -0,0 +1,332 @@ +from typing import Any, Dict, List, Union + +import pytest +import torch +from datasets import load_dataset +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + Trainer, + TrainingArguments, +) + +from peft import CPTConfig, get_peft_model + + +TEMPLATE = {"input": "input: {}", "intra_seperator": " ", "output": "output: {}", "inter_seperator": "\n"} + +MODEL_NAME = "bigscience/bloom-1b7" +MAX_INPUT_LENGTH = 1024 + + +@pytest.fixture(scope="module") +def global_tokenizer(): + """Load the tokenizer fixture for the model.""" + + return AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=".", padding_side="right", trust_remote_code=True) + + +@pytest.fixture(scope="module") +def config_text(): + """Load the SST2 dataset and prepare it for testing.""" + config = CPTConfig( + CPT_token_ids=[0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing + CPT_mask=[1, 1, 1, 1, 1, 1, 1, 1], + CPT_tokens_type_mask=[1, 2, 2, 2, 3, 3, 3, 4], + CPT_prompt_tuning_init="TEXT", + num_virtual_tokens=8, + opt_weighted_loss_type="decay", + opt_loss_decay_factor=0.95, + opt_projection_epsilon=0.2, + opt_projection_format_epsilon=0.1, + tokenizer_name_or_path=MODEL_NAME, + ) + return config + + +@pytest.fixture(scope="module") +def config_random(): + """Load the SST2 dataset and prepare it for testing.""" + config = CPTConfig( + CPT_prompt_tuning_init="RANDOM", + num_virtual_tokens=8, + opt_weighted_loss_type="decay", + opt_loss_decay_factor=0.95, + opt_projection_epsilon=0.2, + opt_projection_format_epsilon=0.1, + tokenizer_name_or_path=MODEL_NAME, + ) + return config + + +@pytest.fixture(scope="module") +def sst_data(): + """Load the SST2 dataset and prepare it for testing.""" + data = load_dataset("glue", "sst2") + + def add_string_labels(example): + if example["label"] == 0: + example["label_text"] = "negative" + elif example["label"] == 1: + example["label_text"] = "positive" + return example + + train_dataset = data["train"].select(range(4)).map(add_string_labels) + test_dataset = data["validation"].select(range(10)).map(add_string_labels) + + return {"train": train_dataset, "test": test_dataset} + + +@pytest.fixture(scope="module") +def collator(global_tokenizer): + class CPTDataCollatorForLanguageModeling(DataCollatorForLanguageModeling): + def __init__(self, tokenizer, training=True, mlm=False): + super().__init__(tokenizer, mlm=mlm) + self.training = training + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # mk check why needed + + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + # Handle dict or lists with proper padding and conversion to tensor. + list_sample_mask = [] + for i in range(len(examples)): + if "sample_mask" in examples[i].keys(): + list_sample_mask.append(examples[i].pop("sample_mask")) + + max_len = max(len(ex["input_ids"]) for ex in examples) + + def pad_sequence(sequence, max_len, pad_value=0): + return sequence + [pad_value] * (max_len - len(sequence)) + + input_ids = torch.tensor([pad_sequence(ex["input_ids"], max_len) for ex in examples]) + attention_mask = torch.tensor([pad_sequence(ex["attention_mask"], max_len) for ex in examples]) + input_type_mask = torch.tensor([pad_sequence(ex["input_type_mask"], max_len) for ex in examples]) + + batch = {"input_ids": input_ids, "attention_mask": attention_mask, "input_type_mask": input_type_mask} + + tensor_sample_mask = batch["input_ids"].clone().long() + tensor_sample_mask[:, :] = 0 + for i in range(len(list_sample_mask)): + tensor_sample_mask[i, : len(list_sample_mask[i])] = list_sample_mask[i] + + batch["labels"] = batch["input_ids"].clone() + if not self.training: + batch["sample_mask"] = tensor_sample_mask + + return batch + + collator = CPTDataCollatorForLanguageModeling(global_tokenizer, training=True, mlm=False) + return collator + + +def dataset(data, tokenizer): + class CPTDataset(Dataset): + def __init__(self, samples, tokenizer, template, max_length=MAX_INPUT_LENGTH): + self.template = template + self.tokenizer = tokenizer + self.max_length = max_length + + self.attention_mask = [] + self.input_ids = [] + self.input_type_mask = [] + self.inter_seperator_ids = self._get_input_ids(template["inter_seperator"]) + + for sample_i in tqdm(samples): + input_text, label = sample_i["sentence"], sample_i["label_text"] + input_ids, attention_mask, input_type_mask = self.preprocess_sentence(input_text, label) + + self.input_ids.append(input_ids) + self.attention_mask.append(attention_mask) + self.input_type_mask.append(input_type_mask) + + def _get_input_ids(self, text): + return self.tokenizer(text, add_special_tokens=False)["input_ids"] + + def preprocess_sentence(self, input_text, label): + input_template_part_1_text, input_template_part_2_text = self.template["input"].split("{}") + input_template_tokenized_part1 = self._get_input_ids(input_template_part_1_text) + input_tokenized = self._get_input_ids(input_text) + input_template_tokenized_part2 = self._get_input_ids(input_template_part_2_text) + + sep_tokenized = self._get_input_ids(self.template["intra_seperator"]) + + label_template_part_1, label_template_part_2 = self.template["output"].split("{}") + label_template_part1_tokenized = self._get_input_ids(label_template_part_1) + label_tokenized = self._get_input_ids(label) + label_template_part2_tokenized = self._get_input_ids(label_template_part_2) + + eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else [] + input_ids = ( + input_template_tokenized_part1 + + input_tokenized + + input_template_tokenized_part2 + + sep_tokenized + + label_template_part1_tokenized + + label_tokenized + + label_template_part2_tokenized + + eos + ) + + # determine label tokens, to calculate loss only over them when labels_loss == True + attention_mask = [1] * len(input_ids) + input_type_mask = ( + [1] * len(input_template_tokenized_part1) + + [2] * len(input_tokenized) + + [1] * len(input_template_tokenized_part2) + + [0] * len(sep_tokenized) + + [3] * len(label_template_part1_tokenized) + + [4] * len(label_tokenized) + + [3] * len(label_template_part2_tokenized) + + [0] * len(eos) + ) + + assert len(input_type_mask) == len(input_ids) == len(attention_mask) + + return input_ids, attention_mask, input_type_mask + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "input_type_mask": self.input_type_mask[idx], + } + + dataset = CPTDataset(data, tokenizer, TEMPLATE) + + return dataset + + +def test_model_initialization_text(global_tokenizer, config_text): + """Test model loading and PEFT model initialization.""" + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + + model = get_peft_model(base_model, config_text) + assert model is not None, "PEFT model initialization failed" + + +def test_model_initialization_random(global_tokenizer, config_random): + """Test model loading and PEFT model initialization.""" + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + + model = get_peft_model(base_model, config_random) + assert model is not None, "PEFT model initialization failed" + + +def test_model_training_random(sst_data, global_tokenizer, collator, config_random): + """Perform a short training run to verify the model and data integration.""" + + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + model = get_peft_model(base_model, config_random) + emb = model.prompt_encoder.default.embedding.weight.data.clone().detach() + training_args = TrainingArguments( + output_dir="./results", + per_device_train_batch_size=1, + num_train_epochs=2, + remove_unused_columns=False, + save_strategy="no", + logging_steps=1, + ) + + train_dataset = dataset(sst_data["train"], global_tokenizer) + + trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) + + try: + trainer.train() + except Exception as e: + pytest.fail(f"Training failed with error: {e}") + + assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) + + model.prompt_encoder.default.projection() + delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() + norm_delta = delta_emb.norm(dim=1).cpu() + epsilon = model.prompt_encoder.default.get_epsilon().cpu() + assert torch.all(norm_delta <= epsilon) + + +def test_model_training_text(sst_data, global_tokenizer, collator, config_text): + """Perform a short training run to verify the model and data integration.""" + + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + model = get_peft_model(base_model, config_text) + emb = model.prompt_encoder.default.embedding.weight.data.clone().detach() + + training_args = TrainingArguments( + output_dir="./results", + per_device_train_batch_size=1, + num_train_epochs=2, + remove_unused_columns=False, + save_strategy="no", + logging_steps=1, + ) + + train_dataset = dataset(sst_data["train"], global_tokenizer) + + trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) + + try: + trainer.train() + except Exception as e: + pytest.fail(f"Training failed with error: {e}") + + assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) + + delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() + norm_delta = delta_emb.norm(dim=1).cpu() + CPT_tokens_type_mask = torch.Tensor(config_text.CPT_tokens_type_mask).long() + non_label_idx = (CPT_tokens_type_mask == 1) | (CPT_tokens_type_mask == 2) | (CPT_tokens_type_mask == 3) + assert torch.all((norm_delta > 0) == non_label_idx) + + model.prompt_encoder.default.projection() + delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() + norm_delta = delta_emb.norm(dim=1).cpu() + epsilon = model.prompt_encoder.default.get_epsilon().cpu() + assert torch.all(norm_delta <= epsilon) + assert torch.all((norm_delta == 0) == (~non_label_idx)) + + +def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_text): + """Perform a short training run to verify the model and data integration.""" + + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + model = get_peft_model(base_model, config_text) + emb = model.prompt_encoder.default.embedding.weight.data.clone().detach() + + training_args = TrainingArguments( + output_dir="./results", + per_device_train_batch_size=2, + num_train_epochs=2, + remove_unused_columns=False, + save_strategy="no", + logging_steps=1, + ) + + train_dataset = dataset(sst_data["train"], global_tokenizer) + + trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) + + try: + trainer.train() + except Exception as e: + pytest.fail(f"Training failed with error: {e}") + + assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) + + delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() + norm_delta = delta_emb.norm(dim=1).cpu() + CPT_tokens_type_mask = torch.Tensor(config_text.CPT_tokens_type_mask).long() + non_label_idx = (CPT_tokens_type_mask == 1) | (CPT_tokens_type_mask == 2) | (CPT_tokens_type_mask == 3) + assert torch.all((norm_delta > 0) == non_label_idx) + + model.prompt_encoder.default.projection() + delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() + norm_delta = delta_emb.norm(dim=1).cpu() + epsilon = model.prompt_encoder.default.get_epsilon().cpu() + assert torch.all(norm_delta <= epsilon) + assert torch.all((norm_delta == 0) == (~non_label_idx)) From 2dfe70fa244b89324c2661264e44d814179cc05d Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Fri, 25 Oct 2024 18:35:44 +0200 Subject: [PATCH 02/14] Added arXiv link to the paper, integrated CPT into testing framework, created _cpt_forward for readability, updated copyright to 2024, renamed class to CPTPromptInit, changed config variables to lowercase and list[int], removed exception catch from tests, added assertion docs, removed batch_size=1 test, and renamed test file to test_cpt.py. --- src/peft/peft_model.py | 91 ++++++++++++++++-------------- src/peft/tuners/cpt/__init__.py | 2 +- src/peft/tuners/cpt/config.py | 36 +++++++++--- src/peft/tuners/cpt/model.py | 82 ++++++++++++++++----------- tests/{CPT_test.py => test_cpt.py} | 89 +++++++++-------------------- tests/testing_common.py | 7 +++ 6 files changed, 161 insertions(+), 146 deletions(-) rename tests/{CPT_test.py => test_cpt.py} (80%) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 8b02755d47..eca7953b1d 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1723,48 +1723,7 @@ def forward( kwargs["past_key_values"] = self.get_prompt(batch_size) return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs) elif peft_config.peft_type == PeftType.CPT: - if peft_config.CPT_prompt_tuning_init == "TEXT": - CPT_token_ids = peft_config.CPT_token_ids - CPT_tokens_type_mask = peft_config.CPT_tokens_type_mask - else: - CPT_token_ids = [0] * peft_config.num_virtual_tokens - CPT_tokens_type_mask = [0] * peft_config.num_virtual_tokens - - # Extract input_type_mask from kwargs and move it to the same device as labels - input_type_mask = kwargs.pop("input_type_mask").to(labels.device) - # Generate embeddings if not provided - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - # Get prompt and concatenate with input embeddings - prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) - prompts = prompts.to(inputs_embeds.dtype) - inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) - # If labels are provided, generate prefix labels and type mask - if labels is not None: - # Generate prefix labels and concatenate with the input labels - prefix_labels = torch.Tensor(CPT_token_ids).long().view(1, -1) - prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device) - CPT_labels = torch.cat((prefix_labels, labels), dim=1) - # Generate prefix type mask and shift input type mask values to avoid conflicts - prefix_type_mask = torch.Tensor(CPT_tokens_type_mask).long().view(1, -1) - prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device) - adjusted_input_type_mask = input_type_mask - adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max() - # Concatenate prefix and shifted input type masks - CPT_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1) - # Identify valid label positions and mask invalid ones with -100 - labels_idx = (CPT_type_mask > 0) & (CPT_type_mask % 4 == 0) - CPT_labels[~labels_idx] = -100 - # Update kwargs with the modified labels - kwargs["labels"] = CPT_labels - # Pass the modified inputs to the base model - base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs) - # Calculate the loss using the custom CPT loss function - base_model_output = CPTEmbedding.calculate_loss( - base_model_output, CPT_labels, CPT_type_mask, self.peft_config["default"] - ) - - return base_model_output + return self._cpt_forward(input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs) else: if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) @@ -1777,6 +1736,54 @@ def forward( inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def _cpt_forward(self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs): + # Extract labels from kwargs + labels = kwargs.pop("labels") + # Extract input_type_mask from kwargs and move it to the same device as labels + input_type_mask = kwargs.pop("input_type_mask").to(labels.device) + + if peft_config.cpt_prompt_tuning_init == "TEXT": + cpt_token_ids = peft_config.cpt_token_ids + cpt_tokens_type_mask = peft_config.cpt_tokens_type_mask + else: + cpt_token_ids = [0] * peft_config.num_virtual_tokens + cpt_tokens_type_mask = [0] * peft_config.num_virtual_tokens + + # Generate embeddings if not provided + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + # Get prompt and concatenate with input embeddings + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + # If labels are provided, generate prefix labels and type mask + if labels is not None: + # Generate prefix labels and concatenate with the input labels + prefix_labels = torch.Tensor(cpt_token_ids).long().view(1, -1) + prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device) + cpt_labels = torch.cat((prefix_labels, labels), dim=1) + # Generate prefix type mask and shift input type mask values to avoid conflicts + prefix_type_mask = torch.Tensor(cpt_tokens_type_mask).long().view(1, -1) + prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device) + adjusted_input_type_mask = input_type_mask + adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max() + # Concatenate prefix and shifted input type masks + cpt_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1) + # Identify valid label positions and mask invalid ones with -100 + labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0) + cpt_labels[~labels_idx] = -100 + # Update kwargs with the modified labels + kwargs["labels"] = cpt_labels + # Pass the modified inputs to the base model + base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs) + # Calculate the loss using the custom CPT loss function + base_model_output = CPTEmbedding.calculate_loss( + base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"] + ) + + return base_model_output + def generate(self, *args, **kwargs): peft_config = self.active_peft_config self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation diff --git a/src/peft/tuners/cpt/__init__.py b/src/peft/tuners/cpt/__init__.py index 67b200cce3..f5018f89b1 100644 --- a/src/peft/tuners/cpt/__init__.py +++ b/src/peft/tuners/cpt/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2024-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py index 98be0c5597..7e9e912910 100644 --- a/src/peft/tuners/cpt/config.py +++ b/src/peft/tuners/cpt/config.py @@ -1,3 +1,17 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import enum from dataclasses import dataclass, field from typing import Optional @@ -8,8 +22,9 @@ from peft.utils import PeftType -class PromptTuningInit(str, enum.Enum): - """Enum for specifying the initialization method for prompt tuning.""" + +class CPTPromptInit(str, enum.Enum): + """Enum for specifying the initialization method for CPT.""" TEXT = "TEXT" # Initialize using text-based embeddings. RANDOM = "RANDOM" # Initialize randomly. @@ -20,21 +35,26 @@ class CPTConfig(PeftConfig): """ CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT). - This class introduces additional parameters required for CPT, such as token type masks, - prompt tuning initialization, loss weighting, and projection settings. + This class introduces additional parameters required for CPT, such as: + - Token type masks + - Prompt tuning initialization + - Loss weighting + - Projection settings + + For more details, see the paper: https://arxiv.org/abs/2410.17222 """ # Token-related configurations - CPT_token_ids: Optional[torch.Tensor] = field( + cpt_token_ids: Optional[list[int]] = field( default=None, metadata={"help": "Tensor of token IDs used for CPT prompts."} ) - CPT_mask: Optional[torch.Tensor] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."}) - CPT_tokens_type_mask: Optional[bool] = field( + cpt_mask: Optional[list[int]] = field(default=None, metadata={"help": "Tensor mask applied to CPT tokens."}) + cpt_tokens_type_mask: Optional[list[int]] = field( default=None, metadata={"help": "Mask indicating the type of each CPT token."} ) # Prompt tuning initialization method - CPT_prompt_tuning_init: Optional[str] = field( + cpt_prompt_tuning_init: Optional[str] = field( default="TEXT", metadata={"help": "Initialization method: 'TEXT' for embedding-based, 'RANDOM' for random."} ) diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py index dceb30590e..52852786a9 100644 --- a/src/peft/tuners/cpt/model.py +++ b/src/peft/tuners/cpt/model.py @@ -1,3 +1,17 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import copy import torch @@ -5,7 +19,7 @@ from peft.utils.integrations import gather_params_ctx -from .config import PromptTuningInit +from .config import CPTPromptInit class CPTEmbedding(torch.nn.Module): @@ -31,10 +45,10 @@ def __init__(self, config, word_embeddings): self.embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim) # Initialize embeddings using text-based prompt tuning, if configured - if config.CPT_prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode: - assert config.num_virtual_tokens == len(config.CPT_token_ids) + if config.cpt_prompt_tuning_init == CPTPromptInit.TEXT and not config.inference_mode: + assert config.num_virtual_tokens == len(config.cpt_token_ids) - init_token_ids = torch.LongTensor(config.CPT_token_ids).to(word_embeddings.weight.device) + init_token_ids = torch.LongTensor(config.cpt_token_ids).to(word_embeddings.weight.device) with gather_params_ctx(word_embeddings.parameters()): word_embedding_weights = word_embeddings(init_token_ids).detach().clone() word_embedding_weights = word_embedding_weights.to(torch.float32) @@ -69,14 +83,14 @@ def set_updated_tokens(self): """ Sets up a backward hook to selectively update token gradients based on the CPT token type mask. """ - if self.config.CPT_prompt_tuning_init == PromptTuningInit.TEXT: - tensor_ICL_mask = torch.Tensor(self.config.CPT_tokens_type_mask).long() + if self.config.cpt_prompt_tuning_init == CPTPromptInit.TEXT: + tensor_ICL_mask = torch.Tensor(self.config.cpt_tokens_type_mask).long() mask_input_template = torch.remainder(tensor_ICL_mask, 4) == 1 mask_input = torch.remainder(tensor_ICL_mask, 4) == 2 mask_output_template = torch.remainder(tensor_ICL_mask, 4) == 3 mask = mask_input_template | mask_input | mask_output_template mask = mask.view(-1, 1) - elif self.config.CPT_prompt_tuning_init == PromptTuningInit.RANDOM: + elif self.config.cpt_prompt_tuning_init == CPTPromptInit.RANDOM: mask = torch.ones((self.config.num_virtual_tokens, 1)).long() def backward_hook(grad): @@ -86,10 +100,10 @@ def backward_hook(grad): self.delta_embedding.weight.register_hook(backward_hook) def get_epsilon(self): - if self.config.CPT_prompt_tuning_init == "TEXT": - CPT_tokens_type_mask = self.config.CPT_tokens_type_mask + if self.config.cpt_prompt_tuning_init == "TEXT": + cpt_tokens_type_mask = self.config.cpt_tokens_type_mask else: - CPT_tokens_type_mask = [2] * self.config.num_virtual_tokens + cpt_tokens_type_mask = [2] * self.config.num_virtual_tokens MIN_VALUE = 1e-10 @@ -104,12 +118,12 @@ def get_epsilon(self): torch.Tensor([self.config.token_dim / 2048]) ) - epsilon = torch.ones_like(torch.Tensor(CPT_tokens_type_mask)).to(torch.float32) * MIN_VALUE - CPT_tokens_type_mask = torch.Tensor(CPT_tokens_type_mask).long() + epsilon = torch.ones_like(torch.Tensor(cpt_tokens_type_mask)).to(torch.float32) * MIN_VALUE + cpt_tokens_type_mask = torch.Tensor(cpt_tokens_type_mask).long() - epsilon[(CPT_tokens_type_mask > 0) & (torch.remainder(CPT_tokens_type_mask, 4) == 1)] = normalized_format_eps - epsilon[(CPT_tokens_type_mask > 0) & (torch.remainder(CPT_tokens_type_mask, 4) == 3)] = normalized_format_eps - epsilon[(CPT_tokens_type_mask > 0) & (torch.remainder(CPT_tokens_type_mask, 4) == 2)] = normalized_input_eps + epsilon[(cpt_tokens_type_mask > 0) & (torch.remainder(cpt_tokens_type_mask, 4) == 1)] = normalized_format_eps + epsilon[(cpt_tokens_type_mask > 0) & (torch.remainder(cpt_tokens_type_mask, 4) == 3)] = normalized_format_eps + epsilon[(cpt_tokens_type_mask > 0) & (torch.remainder(cpt_tokens_type_mask, 4) == 2)] = normalized_input_eps return epsilon @@ -132,14 +146,14 @@ def projection(self): self.delta_embedding.weight.data = new_embeddings_weights @staticmethod - def calculate_loss(base_model_output, labels, CPT_type_mask, config): + def calculate_loss(base_model_output, labels, cpt_type_mask, config): """ Computes the loss for CPT models with optional exponential decay. Args: base_model_output (ModelOutput): Output from the base model containing logits. labels (torch.Tensor): Ground-truth labels for the input tokens. - CPT_type_mask (torch.Tensor): Token type mask used for filtering valid loss terms. + cpt_type_mask (torch.Tensor): Token type mask used for filtering valid loss terms. config (Namespace): Configuration object containing loss-related hyperparameters. Returns: @@ -155,7 +169,7 @@ def calculate_loss(base_model_output, labels, CPT_type_mask, config): # Shift logits and labels for token prediction shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - shift_CPT_type_mask = CPT_type_mask[..., 1:].contiguous() + shift_cpt_type_mask = cpt_type_mask[..., 1:].contiguous() shift_labels_bool = (shift_labels.clone().detach() != -100).bool() batch_size, seq_length, vocab_size = shift_logits.shape @@ -169,13 +183,13 @@ def calculate_loss(base_model_output, labels, CPT_type_mask, config): # Apply exponential decay weights to the loss shift_labels_weights = shift_labels_bool.clone().detach().float() for i in range(batch_size): - idx_labels = (shift_CPT_type_mask[i] > 0) & (shift_CPT_type_mask[i] % 4 == 0) - labels_ids = shift_CPT_type_mask[i][idx_labels].unique() + idx_labels = (shift_cpt_type_mask[i] > 0) & (shift_cpt_type_mask[i] % 4 == 0) + labels_ids = shift_cpt_type_mask[i][idx_labels].unique() - exponential_decay = torch.ones_like(shift_CPT_type_mask[i]).to(device=device).float() + exponential_decay = torch.ones_like(shift_cpt_type_mask[i]).to(device=device).float() decay_value = 1 for label_mask_idx in torch.flip(labels_ids, [0]): - exponential_decay[shift_CPT_type_mask[i] == label_mask_idx] = decay_value + exponential_decay[shift_cpt_type_mask[i] == label_mask_idx] = decay_value decay_value *= config.opt_loss_decay_factor shift_labels_weights[i] *= exponential_decay @@ -188,20 +202,20 @@ def calculate_loss(base_model_output, labels, CPT_type_mask, config): return base_model_output def check_config(self): - if self.config.CPT_prompt_tuning_init == PromptTuningInit.TEXT: - assert self.config.CPT_token_ids is not None - assert self.config.CPT_mask is not None - assert self.config.CPT_tokens_type_mask is not None + if self.config.cpt_prompt_tuning_init == CPTPromptInit.TEXT: + assert self.config.cpt_token_ids is not None + assert self.config.cpt_mask is not None + assert self.config.cpt_tokens_type_mask is not None assert ( - len(self.config.CPT_token_ids) - == len(self.config.CPT_mask) - == len(self.config.CPT_tokens_type_mask) + len(self.config.cpt_token_ids) + == len(self.config.cpt_mask) + == len(self.config.cpt_tokens_type_mask) == self.config.num_virtual_tokens ) - elif self.config.CPT_prompt_tuning_init == PromptTuningInit.RANDOM: - assert self.config.CPT_token_ids is None - assert self.config.CPT_mask is None - assert self.config.CPT_tokens_type_mask is None + elif self.config.cpt_prompt_tuning_init == CPTPromptInit.RANDOM: + assert self.config.cpt_token_ids is None + assert self.config.cpt_mask is None + assert self.config.cpt_tokens_type_mask is None assert self.config.num_virtual_tokens > 0 else: - raise NotImplementedError(f" was not implemented for {self.config.CPT_prompt_tuning_init}") + raise NotImplementedError(f" was not implemented for {self.config.cpt_prompt_tuning_init}") diff --git a/tests/CPT_test.py b/tests/test_cpt.py similarity index 80% rename from tests/CPT_test.py rename to tests/test_cpt.py index 18862f867c..b6c58ba385 100644 --- a/tests/CPT_test.py +++ b/tests/test_cpt.py @@ -1,3 +1,17 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Any, Dict, List, Union import pytest @@ -33,10 +47,10 @@ def global_tokenizer(): def config_text(): """Load the SST2 dataset and prepare it for testing.""" config = CPTConfig( - CPT_token_ids=[0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing - CPT_mask=[1, 1, 1, 1, 1, 1, 1, 1], - CPT_tokens_type_mask=[1, 2, 2, 2, 3, 3, 3, 4], - CPT_prompt_tuning_init="TEXT", + cpt_token_ids=[0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing + cpt_mask=[1, 1, 1, 1, 1, 1, 1, 1], + cpt_tokens_type_mask=[1, 2, 2, 2, 3, 3, 3, 4], + cpt_prompt_tuning_init="TEXT", num_virtual_tokens=8, opt_weighted_loss_type="decay", opt_loss_decay_factor=0.95, @@ -51,7 +65,7 @@ def config_text(): def config_random(): """Load the SST2 dataset and prepare it for testing.""" config = CPTConfig( - CPT_prompt_tuning_init="RANDOM", + cpt_prompt_tuning_init="RANDOM", num_virtual_tokens=8, opt_weighted_loss_type="decay", opt_loss_decay_factor=0.95, @@ -236,59 +250,16 @@ def test_model_training_random(sst_data, global_tokenizer, collator, config_rand trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) - try: - trainer.train() - except Exception as e: - pytest.fail(f"Training failed with error: {e}") - - assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) - - model.prompt_encoder.default.projection() - delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() - norm_delta = delta_emb.norm(dim=1).cpu() - epsilon = model.prompt_encoder.default.get_epsilon().cpu() - assert torch.all(norm_delta <= epsilon) - - -def test_model_training_text(sst_data, global_tokenizer, collator, config_text): - """Perform a short training run to verify the model and data integration.""" - - base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) - model = get_peft_model(base_model, config_text) - emb = model.prompt_encoder.default.embedding.weight.data.clone().detach() - - training_args = TrainingArguments( - output_dir="./results", - per_device_train_batch_size=1, - num_train_epochs=2, - remove_unused_columns=False, - save_strategy="no", - logging_steps=1, - ) - - train_dataset = dataset(sst_data["train"], global_tokenizer) - - trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) - - try: - trainer.train() - except Exception as e: - pytest.fail(f"Training failed with error: {e}") - + trainer.train() + # Verify that the embedding tensor remains unchanged (frozen) assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) - delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() - norm_delta = delta_emb.norm(dim=1).cpu() - CPT_tokens_type_mask = torch.Tensor(config_text.CPT_tokens_type_mask).long() - non_label_idx = (CPT_tokens_type_mask == 1) | (CPT_tokens_type_mask == 2) | (CPT_tokens_type_mask == 3) - assert torch.all((norm_delta > 0) == non_label_idx) - model.prompt_encoder.default.projection() delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() norm_delta = delta_emb.norm(dim=1).cpu() epsilon = model.prompt_encoder.default.get_epsilon().cpu() + # Verify that the change in tokens is constrained to epsilon assert torch.all(norm_delta <= epsilon) - assert torch.all((norm_delta == 0) == (~non_label_idx)) def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_text): @@ -311,22 +282,18 @@ def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_ trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) - try: - trainer.train() - except Exception as e: - pytest.fail(f"Training failed with error: {e}") - + trainer.train() + # Verify that the embedding tensor remains unchanged (frozen) assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) - delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() - norm_delta = delta_emb.norm(dim=1).cpu() - CPT_tokens_type_mask = torch.Tensor(config_text.CPT_tokens_type_mask).long() - non_label_idx = (CPT_tokens_type_mask == 1) | (CPT_tokens_type_mask == 2) | (CPT_tokens_type_mask == 3) - assert torch.all((norm_delta > 0) == non_label_idx) + cpt_tokens_type_mask = torch.Tensor(config_text.cpt_tokens_type_mask).long() + non_label_idx = (cpt_tokens_type_mask == 1) | (cpt_tokens_type_mask == 2) | (cpt_tokens_type_mask == 3) model.prompt_encoder.default.projection() delta_emb = model.prompt_encoder.default.delta_embedding.weight.data.clone().detach() norm_delta = delta_emb.norm(dim=1).cpu() epsilon = model.prompt_encoder.default.get_epsilon().cpu() + # Verify that the change in tokens is constrained to epsilon assert torch.all(norm_delta <= epsilon) + # Ensure that label tokens remain unchanged assert torch.all((norm_delta == 0) == (~non_label_idx)) diff --git a/tests/testing_common.py b/tests/testing_common.py index 954f79be5f..e8e511b19c 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -52,6 +52,7 @@ get_peft_model_state_dict, inject_adapter_in_model, prepare_model_for_kbit_training, + CPTConfig ) from peft.tuners.lora import LoraLayer from peft.utils import _get_submodules, infer_device @@ -119,6 +120,11 @@ { "target_modules": None, }, + # CPT tuninig + { + "num_virtual_tokens": 8, + }, + ) CLASSES_MAPPING = { @@ -134,6 +140,7 @@ "hra": (HRAConfig, CONFIG_TESTING_KWARGS[9]), "vblora": (VBLoRAConfig, CONFIG_TESTING_KWARGS[10]), "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), + "cpt": (CPTConfig, CONFIG_TESTING_KWARGS[12]), } From bd2fc70ee3b5b0f5adf2bc2878c95353cee74497 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Wed, 30 Oct 2024 22:41:31 +0100 Subject: [PATCH 03/14] config: Added config check in __post_init__. Removed redundant initialization in config. Renamed cpt_prompt_tuning_init to cpt_prompt_init. Changed the class from PeftConfig to PromptLearningConfig. model: Removed check_config function. peft_model: Fixed bugs. tests: Added PeftTestConfigManagerForDecoderModels in test_decoder_models.py and testing_common.py. --- src/peft/peft_model.py | 27 ++++++++++++----- src/peft/tuners/cpt/config.py | 55 +++++++++++++++++++++++++++++++---- src/peft/tuners/cpt/model.py | 27 +++-------------- tests/test_cpt.py | 4 +-- tests/test_decoder_models.py | 3 +- tests/testing_common.py | 26 +++++++++++++++-- 6 files changed, 98 insertions(+), 44 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index eca7953b1d..8415fbd1b5 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1740,10 +1740,19 @@ def forward( def _cpt_forward(self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs): # Extract labels from kwargs labels = kwargs.pop("labels") + device = [i.device for i in[input_ids, inputs_embeds, labels] if i is not None][0] # Extract input_type_mask from kwargs and move it to the same device as labels - input_type_mask = kwargs.pop("input_type_mask").to(labels.device) + if 'input_type_mask' in kwargs.keys(): + input_type_mask = kwargs.pop("input_type_mask").to(device) + else: + if input_ids is None: + N_tokens = inputs_embeds.shape[1] + else: + N_tokens = input_ids.shape[1] + input_type_mask = torch.zeros((batch_size, N_tokens)).to(device) + input_type_mask[:, -1] = 4 - if peft_config.cpt_prompt_tuning_init == "TEXT": + if peft_config.cpt_prompt_init == "TEXT": cpt_token_ids = peft_config.cpt_token_ids cpt_tokens_type_mask = peft_config.cpt_tokens_type_mask else: @@ -1777,12 +1786,14 @@ def _cpt_forward(self, input_ids=None, inputs_embeds=None, peft_config=None, tas kwargs["labels"] = cpt_labels # Pass the modified inputs to the base model base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs) - # Calculate the loss using the custom CPT loss function - base_model_output = CPTEmbedding.calculate_loss( - base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"] - ) - - return base_model_output + if labels is None: + return base_model_output + else: + # Calculate the loss using the custom CPT loss function + base_model_output = CPTEmbedding.calculate_loss( + base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"] + ) + return base_model_output def generate(self, *args, **kwargs): peft_config = self.active_peft_config diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py index 7e9e912910..4d5ee105a6 100644 --- a/src/peft/tuners/cpt/config.py +++ b/src/peft/tuners/cpt/config.py @@ -16,9 +16,7 @@ from dataclasses import dataclass, field from typing import Optional -import torch - -from peft.config import PeftConfig +from peft.config import PromptLearningConfig from peft.utils import PeftType @@ -31,7 +29,7 @@ class CPTPromptInit(str, enum.Enum): @dataclass -class CPTConfig(PeftConfig): +class CPTConfig(PromptLearningConfig): """ CPT Configuration class extending PeftConfig for Context-aware Prompt Tuning (CPT). @@ -54,7 +52,7 @@ class CPTConfig(PeftConfig): ) # Prompt tuning initialization method - cpt_prompt_tuning_init: Optional[str] = field( + cpt_prompt_init: Optional[str] = field( default="TEXT", metadata={"help": "Initialization method: 'TEXT' for embedding-based, 'RANDOM' for random."} ) @@ -98,5 +96,50 @@ def __post_init__(self): Post-initialization hook to set additional attributes after the config is initialized. """ self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. - self.target_modules = None # Placeholder for target modules in CPT. self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. + + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_token_ids is None: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " + f"cpt_token_ids can't be None." + ) + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_mask is None: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " + f"cpt_mask can't be None." + ) + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_tokens_type_mask is None: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " + f"cpt_tokens_type_mask can't be None." + ) + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and not (len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens): + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " + f"cpt_token_ids, cpt_mask and cpt_tokens_type_mask must have the same length." + ) + + if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_token_ids is not None: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " + f"cpt_token_ids must be None." + ) + if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_mask is not None: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " + f"cpt_mask must be None." + ) + if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_tokens_type_mask is not None: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " + f"cpt_tokens_type_mask must be None." + ) + if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.num_virtual_tokens == 0: + raise ValueError( + f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " + f"num_virtual_tokens must be greater than zero." + ) + if (self.cpt_prompt_init != CPTPromptInit.RANDOM) and (self.cpt_prompt_init != CPTPromptInit.TEXT): + raise ValueError( + f"prompt_tuning_init must be 'RANDOM' or 'TEXT'" + ) diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py index 52852786a9..2072f3fde7 100644 --- a/src/peft/tuners/cpt/model.py +++ b/src/peft/tuners/cpt/model.py @@ -38,14 +38,13 @@ def __init__(self, config, word_embeddings): """ super().__init__() self.config = copy.deepcopy(config) - self.check_config() num_virtual_tokens = config.num_virtual_tokens # Initialize embeddings with virtual token dimensions self.embedding = torch.nn.Embedding(num_virtual_tokens, config.token_dim) # Initialize embeddings using text-based prompt tuning, if configured - if config.cpt_prompt_tuning_init == CPTPromptInit.TEXT and not config.inference_mode: + if config.cpt_prompt_init == CPTPromptInit.TEXT and not config.inference_mode: assert config.num_virtual_tokens == len(config.cpt_token_ids) init_token_ids = torch.LongTensor(config.cpt_token_ids).to(word_embeddings.weight.device) @@ -83,14 +82,14 @@ def set_updated_tokens(self): """ Sets up a backward hook to selectively update token gradients based on the CPT token type mask. """ - if self.config.cpt_prompt_tuning_init == CPTPromptInit.TEXT: + if self.config.cpt_prompt_init == CPTPromptInit.TEXT: tensor_ICL_mask = torch.Tensor(self.config.cpt_tokens_type_mask).long() mask_input_template = torch.remainder(tensor_ICL_mask, 4) == 1 mask_input = torch.remainder(tensor_ICL_mask, 4) == 2 mask_output_template = torch.remainder(tensor_ICL_mask, 4) == 3 mask = mask_input_template | mask_input | mask_output_template mask = mask.view(-1, 1) - elif self.config.cpt_prompt_tuning_init == CPTPromptInit.RANDOM: + elif self.config.cpt_prompt_init == CPTPromptInit.RANDOM: mask = torch.ones((self.config.num_virtual_tokens, 1)).long() def backward_hook(grad): @@ -100,7 +99,7 @@ def backward_hook(grad): self.delta_embedding.weight.register_hook(backward_hook) def get_epsilon(self): - if self.config.cpt_prompt_tuning_init == "TEXT": + if self.config.cpt_prompt_init == "TEXT": cpt_tokens_type_mask = self.config.cpt_tokens_type_mask else: cpt_tokens_type_mask = [2] * self.config.num_virtual_tokens @@ -201,21 +200,3 @@ def calculate_loss(base_model_output, labels, cpt_type_mask, config): return base_model_output - def check_config(self): - if self.config.cpt_prompt_tuning_init == CPTPromptInit.TEXT: - assert self.config.cpt_token_ids is not None - assert self.config.cpt_mask is not None - assert self.config.cpt_tokens_type_mask is not None - assert ( - len(self.config.cpt_token_ids) - == len(self.config.cpt_mask) - == len(self.config.cpt_tokens_type_mask) - == self.config.num_virtual_tokens - ) - elif self.config.cpt_prompt_tuning_init == CPTPromptInit.RANDOM: - assert self.config.cpt_token_ids is None - assert self.config.cpt_mask is None - assert self.config.cpt_tokens_type_mask is None - assert self.config.num_virtual_tokens > 0 - else: - raise NotImplementedError(f" was not implemented for {self.config.cpt_prompt_tuning_init}") diff --git a/tests/test_cpt.py b/tests/test_cpt.py index b6c58ba385..a351e313a0 100644 --- a/tests/test_cpt.py +++ b/tests/test_cpt.py @@ -50,7 +50,7 @@ def config_text(): cpt_token_ids=[0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing cpt_mask=[1, 1, 1, 1, 1, 1, 1, 1], cpt_tokens_type_mask=[1, 2, 2, 2, 3, 3, 3, 4], - cpt_prompt_tuning_init="TEXT", + cpt_prompt_init="TEXT", num_virtual_tokens=8, opt_weighted_loss_type="decay", opt_loss_decay_factor=0.95, @@ -65,7 +65,7 @@ def config_text(): def config_random(): """Load the SST2 dataset and prepare it for testing.""" config = CPTConfig( - cpt_prompt_tuning_init="RANDOM", + cpt_prompt_init="RANDOM", num_virtual_tokens=8, opt_weighted_loss_type="decay", opt_loss_decay_factor=0.95, diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 3ad373ac01..2f6a78995f 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -39,8 +39,7 @@ get_peft_model, ) -from .testing_common import PeftCommonTester, PeftTestConfigManager - +from .testing_common import PeftCommonTester, PeftTestConfigManagerForDecoderModels as PeftTestConfigManager PEFT_DECODER_MODELS_TO_TEST = [ "hf-internal-testing/tiny-random-OPTForCausalLM", diff --git a/tests/testing_common.py b/tests/testing_common.py index e8e511b19c..0a80b25ba4 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -123,6 +123,10 @@ # CPT tuninig { "num_virtual_tokens": 8, + "cpt_token_ids": [0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing + "cpt_mask": [1, 1, 1, 1, 1, 1, 1, 1], + "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], + "cpt_prompt_init": "TEXT", }, ) @@ -140,9 +144,11 @@ "hra": (HRAConfig, CONFIG_TESTING_KWARGS[9]), "vblora": (VBLoRAConfig, CONFIG_TESTING_KWARGS[10]), "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), - "cpt": (CPTConfig, CONFIG_TESTING_KWARGS[12]), } +DECODER_MODELS_EXTRA = { + "cpt": (CPTConfig, CONFIG_TESTING_KWARGS[12]) +} # Adapted from https://github.com/huggingface/transformers/blob/48327c57182fdade7f7797d1eaad2d166de5c55b/src/transformers/activations.py#LL166C7-L166C22 class ClassInstantier(OrderedDict): @@ -207,7 +213,7 @@ def get_grid_parameters(self, grid_parameters, filter_params_func=None): PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING) - +PeftTestConfigManagerForDecoderModels = ClassInstantier({**CLASSES_MAPPING, **DECODER_MODELS_EXTRA}) class PeftCommonTester: r""" @@ -1196,10 +1202,24 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar loss = output.sum() loss.backward() + if issubclass(config_cls, CPTConfig): + parameters = [] + list_names = [] + for name, param in model.prompt_encoder.named_parameters(): + if name not in ['default.embedding.weight']: + parameters.append(param) + list_names.append(name) + else: + assert param.grad is None + '' + else: + parameters = model.prompt_encoder.parameters() + # check that prompt encoder has grads - for param in model.prompt_encoder.parameters(): + for param in parameters: assert param.grad is not None + def _test_delete_adapter(self, model_id, config_cls, config_kwargs): supported_peft_types = [ PeftType.LORA, From 77bb0b953e6094052d8648cf3c5ad93be26077c5 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Sun, 3 Nov 2024 10:14:59 +0100 Subject: [PATCH 04/14] tests: Updated test_cpt and testing_common as per the PR requirements. --- tests/test_cpt.py | 12 ++++++------ tests/testing_common.py | 17 +++++------------ 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/tests/test_cpt.py b/tests/test_cpt.py index a351e313a0..891855c8f3 100644 --- a/tests/test_cpt.py +++ b/tests/test_cpt.py @@ -32,7 +32,7 @@ TEMPLATE = {"input": "input: {}", "intra_seperator": " ", "output": "output: {}", "inter_seperator": "\n"} -MODEL_NAME = "bigscience/bloom-1b7" +MODEL_NAME = "hf-internal-testing/tiny-random-OPTForCausalLM" MAX_INPUT_LENGTH = 1024 @@ -40,7 +40,7 @@ def global_tokenizer(): """Load the tokenizer fixture for the model.""" - return AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=".", padding_side="right", trust_remote_code=True) + return AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right") @pytest.fixture(scope="module") @@ -217,7 +217,7 @@ def __getitem__(self, idx): def test_model_initialization_text(global_tokenizer, config_text): """Test model loading and PEFT model initialization.""" - base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) model = get_peft_model(base_model, config_text) assert model is not None, "PEFT model initialization failed" @@ -225,7 +225,7 @@ def test_model_initialization_text(global_tokenizer, config_text): def test_model_initialization_random(global_tokenizer, config_random): """Test model loading and PEFT model initialization.""" - base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) model = get_peft_model(base_model, config_random) assert model is not None, "PEFT model initialization failed" @@ -234,7 +234,7 @@ def test_model_initialization_random(global_tokenizer, config_random): def test_model_training_random(sst_data, global_tokenizer, collator, config_random): """Perform a short training run to verify the model and data integration.""" - base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) model = get_peft_model(base_model, config_random) emb = model.prompt_encoder.default.embedding.weight.data.clone().detach() training_args = TrainingArguments( @@ -265,7 +265,7 @@ def test_model_training_random(sst_data, global_tokenizer, collator, config_rand def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_text): """Perform a short training run to verify the model and data integration.""" - base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) + base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) model = get_peft_model(base_model, config_text) emb = model.prompt_encoder.default.embedding.weight.data.clone().detach() diff --git a/tests/testing_common.py b/tests/testing_common.py index d3fbd6c6e1..17ebf53577 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -32,6 +32,7 @@ from peft import ( AdaLoraConfig, BOFTConfig, + CPTConfig, FourierFTConfig, HRAConfig, IA3Config, @@ -52,7 +53,6 @@ get_peft_model_state_dict, inject_adapter_in_model, prepare_model_for_kbit_training, - CPTConfig ) from peft.tuners.lora import LoraLayer from peft.utils import _get_submodules, infer_device @@ -128,7 +128,6 @@ "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], "cpt_prompt_init": "TEXT", }, - ) CLASSES_MAPPING = { @@ -146,9 +145,8 @@ "oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]), } -DECODER_MODELS_EXTRA = { - "cpt": (CPTConfig, CONFIG_TESTING_KWARGS[12]) -} +DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[12])} + # Adapted from https://github.com/huggingface/transformers/blob/48327c57182fdade7f7797d1eaad2d166de5c55b/src/transformers/activations.py#LL166C7-L166C22 class ClassInstantier(OrderedDict): @@ -215,6 +213,7 @@ def get_grid_parameters(self, grid_parameters, filter_params_func=None): PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING) PeftTestConfigManagerForDecoderModels = ClassInstantier({**CLASSES_MAPPING, **DECODER_MODELS_EXTRA}) + class PeftCommonTester: r""" A large testing suite for testing common functionality of the PEFT models. @@ -1206,14 +1205,9 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar if issubclass(config_cls, CPTConfig): parameters = [] - list_names = [] for name, param in model.prompt_encoder.named_parameters(): - if name not in ['default.embedding.weight']: + if name != "default.embedding.weight": parameters.append(param) - list_names.append(name) - else: - assert param.grad is None - '' else: parameters = model.prompt_encoder.parameters() @@ -1221,7 +1215,6 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar for param in parameters: assert param.grad is not None - def _test_delete_adapter(self, model_id, config_cls, config_kwargs): supported_peft_types = [ PeftType.LORA, From dbcdedfec9c5c883684f56d2992a8623799386b3 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Sun, 3 Nov 2024 13:40:46 +0100 Subject: [PATCH 05/14] Created cpt.md in package_regerence. Updated the prompting.md file. added into _toctree.yml. --- docs/source/_toctree.yml | 3 +++ docs/source/conceptual_guides/prompting.md | 16 +++++++++++ docs/source/package_reference/cpt.md | 31 ++++++++++++++++++++++ 3 files changed, 50 insertions(+) create mode 100644 docs/source/package_reference/cpt.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index fe66a3d6c4..5b14623ec9 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -118,6 +118,9 @@ title: VB-LoRA - local: package_reference/hra title: HRA + - local: package_reference/cpt + title: CPT + title: Adapters - sections: diff --git a/docs/source/conceptual_guides/prompting.md b/docs/source/conceptual_guides/prompting.md index 810222c4b3..ccdda6559a 100644 --- a/docs/source/conceptual_guides/prompting.md +++ b/docs/source/conceptual_guides/prompting.md @@ -75,3 +75,19 @@ Take a look at [P-tuning for sequence classification](../task_guides/ptuning-seq Prompt decomposition. + + +## Context-Aware Prompt Tuning (CPT) + +
+ +
+CPT optimizing only specific token embeddings while keeping the rest of the model frozen (image source). + +[Context-Aware Prompt Tuning (CPT)](https://huggingface.co/papers/2410.17222) is designed to enhance few-shot classification by refining only context embeddings. +This approach combines ideas from In-Context Learning (ICL), Prompt Tuning (PT), and adversarial optimization, focusing on making model adaptation both parameter-efficient and effective. +In CPT, only specific context token embeddings are optimized, while the rest of the model remains frozen. +To prevent overfitting and maintain stability, CPT uses controlled perturbations to limit the allowed changes to context embeddings within a defined range. +Additionally, to address the phenomenon of recency bias—where examples near the end of the context tend to be prioritized over earlier ones—CPT applies a decay loss factor. + +Take a look at [Context-Aware Prompt Tuning for few-shot classification](../task_guides/cpt-few-shot-classification) for a step-by-step guide on how to train a model with CPT. diff --git a/docs/source/package_reference/cpt.md b/docs/source/package_reference/cpt.md new file mode 100644 index 0000000000..28ad9721c1 --- /dev/null +++ b/docs/source/package_reference/cpt.md @@ -0,0 +1,31 @@ + + +# Context-Aware Prompt Tuning (CPT) + +[Context-aware Prompt Tuning: Advancing In-Context Learning with Adversarial Methods (CPT)](https://huggingface.co/papers/2410.17222) combines In-Context Learning (ICL) with Prompt Tuning (PT) and adversarial optimization to improve few-shot learning by refining context embeddings. CPT optimizes only context tokens, which minimizes overfitting and enhances performance on classification tasks. + +The abstract from the paper is: + +*Traditional fine-tuning is effective but computationally intensive, as it requires updating billions of parameters. CPT, inspired by ICL, PT, and adversarial attacks, refines context embeddings in a parameter-efficient manner. By optimizing context tokens and applying a controlled gradient descent, CPT achieves superior accuracy across various few-shot classification tasks, showing significant improvement over existing methods such as LoRA, PT, and ICL.* + +## CPTConfig + +[[autodoc]] tuners.cpt.config.CPTConfig + +## CPTModel + +[[autodoc]] tuners.cpt.model.CPTModel + From 0a5fb208f8db314a2d143ff0e053b999d656aa91 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Tue, 5 Nov 2024 09:02:43 +0100 Subject: [PATCH 06/14] verifying that the model is causal LM --- src/peft/peft_model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 758d9a87fe..049f51335f 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -655,6 +655,8 @@ def _setup_prompt_encoder(self, adapter_name: str): raise ValueError("Prefix tuning does not work with gradient checkpointing.") prompt_encoder = PrefixEncoder(config) elif config.peft_type == PeftType.CPT: + if not self.base_model.config.is_decoder: + raise ValueError("CPT works only with causal LM models.") prompt_encoder = CPTEmbedding(config, self.word_embeddings) else: raise ValueError("Not supported") @@ -1762,13 +1764,14 @@ def forward( inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) return self.base_model(inputs_embeds=inputs_embeds, **kwargs) - - def _cpt_forward(self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs): + def _cpt_forward( + self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs + ): # Extract labels from kwargs labels = kwargs.pop("labels") - device = [i.device for i in[input_ids, inputs_embeds, labels] if i is not None][0] + device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0] # Extract input_type_mask from kwargs and move it to the same device as labels - if 'input_type_mask' in kwargs.keys(): + if "input_type_mask" in kwargs.keys(): input_type_mask = kwargs.pop("input_type_mask").to(device) else: if input_ids is None: From 7206db512ff046c24b145a65f28b609f7d0753ae Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Tue, 5 Nov 2024 12:40:58 +0100 Subject: [PATCH 07/14] Changed CPTModel to CPTEmbedding --- docs/source/package_reference/cpt.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/package_reference/cpt.md b/docs/source/package_reference/cpt.md index 28ad9721c1..7ad7240ed0 100644 --- a/docs/source/package_reference/cpt.md +++ b/docs/source/package_reference/cpt.md @@ -25,7 +25,7 @@ The abstract from the paper is: [[autodoc]] tuners.cpt.config.CPTConfig -## CPTModel +## CPTEmbedding -[[autodoc]] tuners.cpt.model.CPTModel +[[autodoc]] tuners.cpt.model.CPTEmbedding From 81ffa098a04b56fa9c1034e153829ce575d39ba4 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Thu, 7 Nov 2024 21:17:19 +0100 Subject: [PATCH 08/14] make style --- tests/testing_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/testing_common.py b/tests/testing_common.py index 538f775830..1a3015c045 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -134,7 +134,6 @@ "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], "cpt_prompt_init": "TEXT", }, - ) CLASSES_MAPPING = { From 130ec76b900b3f166e611cf93f35b29f882bb98f Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Thu, 7 Nov 2024 21:56:27 +0100 Subject: [PATCH 09/14] make style --- src/peft/tuners/cpt/config.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py index 4d5ee105a6..14cbfbf8ba 100644 --- a/src/peft/tuners/cpt/config.py +++ b/src/peft/tuners/cpt/config.py @@ -20,7 +20,6 @@ from peft.utils import PeftType - class CPTPromptInit(str, enum.Enum): """Enum for specifying the initialization method for CPT.""" @@ -99,21 +98,16 @@ def __post_init__(self): self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_token_ids is None: - raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " - f"cpt_token_ids can't be None." - ) + raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_token_ids can't be None.") if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_mask is None: - raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " - f"cpt_mask can't be None." - ) + raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_mask can't be None.") if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_tokens_type_mask is None: raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " - f"cpt_tokens_type_mask can't be None." + f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_tokens_type_mask can't be None." ) - if (self.cpt_prompt_init == CPTPromptInit.TEXT) and not (len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens): + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and not ( + len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens + ): raise ValueError( f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_token_ids, cpt_mask and cpt_tokens_type_mask must have the same length." @@ -121,18 +115,13 @@ def __post_init__(self): if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_token_ids is not None: raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " - f"cpt_token_ids must be None." + f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " f"cpt_token_ids must be None." ) if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_mask is not None: - raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " - f"cpt_mask must be None." - ) + raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " f"cpt_mask must be None.") if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.cpt_tokens_type_mask is not None: raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " - f"cpt_tokens_type_mask must be None." + f"When prompt_tuning_init='{CPTPromptInit.RANDOM.value}', " f"cpt_tokens_type_mask must be None." ) if (self.cpt_prompt_init == CPTPromptInit.RANDOM) and self.num_virtual_tokens == 0: raise ValueError( @@ -140,6 +129,4 @@ def __post_init__(self): f"num_virtual_tokens must be greater than zero." ) if (self.cpt_prompt_init != CPTPromptInit.RANDOM) and (self.cpt_prompt_init != CPTPromptInit.TEXT): - raise ValueError( - f"prompt_tuning_init must be 'RANDOM' or 'TEXT'" - ) + raise ValueError("prompt_tuning_init must be 'RANDOM' or 'TEXT'") From 70067d827dd9699f6dbe910f88c2b388e5c96abe Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Thu, 7 Nov 2024 22:11:50 +0100 Subject: [PATCH 10/14] make style --- src/peft/tuners/cpt/model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py index 2072f3fde7..da18823219 100644 --- a/src/peft/tuners/cpt/model.py +++ b/src/peft/tuners/cpt/model.py @@ -107,11 +107,8 @@ def get_epsilon(self): MIN_VALUE = 1e-10 # Calculate normalized epsilon values for input, output, and format tokens - normalized_format_eps = ( - self.config.opt_projection_format_epsilon - * torch.sqrt( - torch.Tensor([self.config.token_dim / 2048]) - ) + normalized_format_eps = self.config.opt_projection_format_epsilon * torch.sqrt( + torch.Tensor([self.config.token_dim / 2048]) ) normalized_input_eps = self.config.opt_projection_epsilon * torch.sqrt( torch.Tensor([self.config.token_dim / 2048]) @@ -199,4 +196,3 @@ def calculate_loss(base_model_output, labels, cpt_type_mask, config): raise NotImplementedError(f"Loss type '{config.opt_weighted_loss_type}' not implemented.") return base_model_output - From 93973144cf2dccfa6797ef07ebe319b1f9b24d88 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Fri, 8 Nov 2024 07:09:37 +0100 Subject: [PATCH 11/14] make doc --- src/peft/tuners/cpt/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/cpt/model.py b/src/peft/tuners/cpt/model.py index da18823219..7ab33e7ceb 100644 --- a/src/peft/tuners/cpt/model.py +++ b/src/peft/tuners/cpt/model.py @@ -24,8 +24,8 @@ class CPTEmbedding(torch.nn.Module): """ - CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. - It initializes embeddings, applies prompt-specific projections, and computes loss using label masks. + CPTEmbedding is a custom embedding layer designed for Context-aware Prompt Tuning (CPT) in PEFT. It initializes + embeddings, applies prompt-specific projections, and computes loss using label masks. """ def __init__(self, config, word_embeddings): From 0a434730600dd35feaacb88aeee65bf4d36a8477 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Sun, 10 Nov 2024 20:21:58 +0100 Subject: [PATCH 12/14] Removed redundant checks --- src/peft/peft_model.py | 2 -- src/peft/tuners/cpt/config.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 7ca1a70ad1..d7941d2733 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -657,8 +657,6 @@ def _setup_prompt_encoder(self, adapter_name: str): raise ValueError("Prefix tuning does not work with gradient checkpointing.") prompt_encoder = PrefixEncoder(config) elif config.peft_type == PeftType.CPT: - if not self.base_model.config.is_decoder: - raise ValueError("CPT works only with causal LM models.") prompt_encoder = CPTEmbedding(config, self.word_embeddings) else: raise ValueError("Not supported") diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py index 14cbfbf8ba..34002af8e8 100644 --- a/src/peft/tuners/cpt/config.py +++ b/src/peft/tuners/cpt/config.py @@ -97,8 +97,6 @@ def __post_init__(self): self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. - if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_token_ids is None: - raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_token_ids can't be None.") if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_mask is None: raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_mask can't be None.") if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_tokens_type_mask is None: From 144f042bfbcb20d26ec9f359bccdeeac601e92e1 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Wed, 13 Nov 2024 06:21:39 +0100 Subject: [PATCH 13/14] Fixed errors --- src/peft/peft_model.py | 6 ++++-- src/peft/tuners/cpt/config.py | 12 ++++++++---- tests/testing_common.py | 5 ++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index d7941d2733..8cb755a2c7 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1779,7 +1779,7 @@ def _cpt_forward( else: N_tokens = input_ids.shape[1] input_type_mask = torch.zeros((batch_size, N_tokens)).to(device) - input_type_mask[:, -1] = 4 + input_type_mask[:, :] = 4 if peft_config.cpt_prompt_init == "TEXT": cpt_token_ids = peft_config.cpt_token_ids @@ -1796,6 +1796,7 @@ def _cpt_forward( prompts = prompts.to(inputs_embeds.dtype) inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) # If labels are provided, generate prefix labels and type mask + cpt_labels = None if labels is not None: # Generate prefix labels and concatenate with the input labels prefix_labels = torch.Tensor(cpt_token_ids).long().view(1, -1) @@ -1812,7 +1813,8 @@ def _cpt_forward( labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0) cpt_labels[~labels_idx] = -100 # Update kwargs with the modified labels - kwargs["labels"] = cpt_labels + + kwargs["labels"] = cpt_labels # Pass the modified inputs to the base model base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs) if labels is None: diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py index 34002af8e8..bc2605d5af 100644 --- a/src/peft/tuners/cpt/config.py +++ b/src/peft/tuners/cpt/config.py @@ -97,12 +97,16 @@ def __post_init__(self): self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_token_ids is None: + self.cpt_token_ids = [0] + self.num_virtual_tokens = 1 + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_mask is None: - raise ValueError(f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_mask can't be None.") + self.cpt_mask = [1 for _ in self.cpt_token_ids] + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and self.cpt_tokens_type_mask is None: - raise ValueError( - f"When prompt_tuning_init='{CPTPromptInit.TEXT.value}', " f"cpt_tokens_type_mask can't be None." - ) + self.cpt_tokens_type_mask = [1 for _ in self.cpt_token_ids] + if (self.cpt_prompt_init == CPTPromptInit.TEXT) and not ( len(self.cpt_token_ids) == len(self.cpt_mask) == len(self.cpt_tokens_type_mask) == self.num_virtual_tokens ): diff --git a/tests/testing_common.py b/tests/testing_common.py index 1a3015c045..b99bb27d13 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1164,7 +1164,10 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa for n, param in model.named_parameters(): if "prompt_encoder." in n: # prompt tuning methods - assert param.grad is not None + if not issubclass(config_cls, CPTConfig): + assert param.grad is not None + elif "delta_embedding" in n: + assert param.grad is not None elif hasattr(model, "prefix") and (model.prefix in n): # non-prompt tuning methods assert param.grad is not None else: From dacb400e239d3198311e6c2d3665402a32784728 Mon Sep 17 00:00:00 2001 From: tsachiblau Date: Wed, 13 Nov 2024 20:50:39 +0100 Subject: [PATCH 14/14] Minor code updates. --- src/peft/peft_model.py | 3 +-- src/peft/tuners/cpt/config.py | 4 ++-- tests/testing_common.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 8cb755a2c7..547cf9defc 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1778,8 +1778,7 @@ def _cpt_forward( N_tokens = inputs_embeds.shape[1] else: N_tokens = input_ids.shape[1] - input_type_mask = torch.zeros((batch_size, N_tokens)).to(device) - input_type_mask[:, :] = 4 + input_type_mask = torch.ones((batch_size, N_tokens)).to(device) * 4 if peft_config.cpt_prompt_init == "TEXT": cpt_token_ids = peft_config.cpt_token_ids diff --git a/src/peft/tuners/cpt/config.py b/src/peft/tuners/cpt/config.py index bc2605d5af..9321ff1c59 100644 --- a/src/peft/tuners/cpt/config.py +++ b/src/peft/tuners/cpt/config.py @@ -14,7 +14,7 @@ import enum from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Literal from peft.config import PromptLearningConfig from peft.utils import PeftType @@ -51,7 +51,7 @@ class CPTConfig(PromptLearningConfig): ) # Prompt tuning initialization method - cpt_prompt_init: Optional[str] = field( + cpt_prompt_init: Optional[Literal[CPTPromptInit.TEXT, CPTPromptInit.RANDOM]] = field( default="TEXT", metadata={"help": "Initialization method: 'TEXT' for embedding-based, 'RANDOM' for random."} ) diff --git a/tests/testing_common.py b/tests/testing_common.py index b99bb27d13..e16bf949ca 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1166,7 +1166,7 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, config_kwa if "prompt_encoder." in n: # prompt tuning methods if not issubclass(config_cls, CPTConfig): assert param.grad is not None - elif "delta_embedding" in n: + elif "delta_embedding" in n: # delta_embedding is the embedding that should be updated with grads in CPT assert param.grad is not None elif hasattr(model, "prefix") and (model.prefix in n): # non-prompt tuning methods assert param.grad is not None