-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CPT Tuner #2168
base: main
Are you sure you want to change the base?
CPT Tuner #2168
Conversation
Hi, thanks for creating this PR to add this new method to PEFT. I did not have time yet to do a review, but I wanted to alert you that with the provided information, you could be de-anonymized as paper author. Not sure if that's a big deal for the submission process, but just wanted to let you know. |
I plan to upload it to arXiv anyway, so that works for me. Thanks for letting me know. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR to add the Context-aware Prompt Tuning method to PEFT.
I plan to upload it to arXiv anyway, so that works for me. Thanks for letting me know.
It's now there, right: https://arxiv.org/abs/2410.17222? Let's add a link to the paper into the docstring of the config class for reference.
Something small I noticed:
VERA (Kopiczko et al., 2023) builds on LoRA by incorporating adaptive learning rates
I think this is not an accurate characterization of VeRA. Did you maybe mean to reference LoRA+ instead?
I reviewed your method and added a couple of comments, please check them out. On top of these, I have a more general question for my understanding: According to the paper, part of the reason why this method works will relates to the changes in the loss calculation and how the parameters are updated. Is this fully covered by the CPT implementation here or would users need to consider something in addition when defining their training loop?
Regarding the testing, thanks for including a few functional tests. Let's also add CPT to the general testing framework, similar to how we do it for prompt tuning:
Line 129 in fb6108a
"prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[4]), |
However, IIUC, it only works for decoder models (causal LM), right? That means we need to create a separate PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING)
instance that uses CLASSES_MAPPING
with CPT added on top. This instance should then be used in test_decoder_models.py
.
Finally, before merging this, we also need to add some docs and at least one example. However, we can work on those in a later iteration and iron out the implementation first.
src/peft/tuners/cpt/__init__.py
Outdated
@@ -0,0 +1,20 @@ | |||
# Copyright 2023-present the HuggingFace Inc. team. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Copyright 2023-present the HuggingFace Inc. team. | |
# Copyright 2024-present the HuggingFace Inc. team. |
src/peft/tuners/cpt/config.py
Outdated
from peft.utils import PeftType | ||
|
||
|
||
class PromptTuningInit(str, enum.Enum): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is identical to the PromptTuningInit
class from prompt_tuning/config.py
, right? I wonder if we should give it a different name to avoid confusion. If you think the options will always be the same, we can also import that class here instead.
src/peft/tuners/cpt/config.py
Outdated
""" | ||
|
||
# Token-related configurations | ||
CPT_token_ids: Optional[torch.Tensor] = field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For variable names, let's avoid using capitalization. So this variable should be cpt_token_ids
, same with all the variables below.
Also, since this and the next arguments should not be None
, I think it makes more sense to remove the default=None
and to remove the Optional
type annotation. Then we don't need to check that in check_config
.
Finally, list[int]
would also be valid here, right?
tests/CPT_test.py
Outdated
|
||
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator) | ||
|
||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to catch the exception, just let it fail.
tests/CPT_test.py
Outdated
except Exception as e: | ||
pytest.fail(f"Training failed with error: {e}") | ||
|
||
assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my understanding, this is to test that these embeddings are frozen? Let's add a comment.
tests/CPT_test.py
Outdated
assert torch.all(norm_delta <= epsilon) | ||
|
||
|
||
def test_model_training_text(sst_data, global_tokenizer, collator, config_text): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar comments to the test above.
tests/CPT_test.py
Outdated
assert torch.all((norm_delta == 0) == (~non_label_idx)) | ||
|
||
|
||
def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_text): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same argument as for the tests above. Also, what exactly is different in this test, just the batch size? Why is it important to test batch size 1 and 2?
tests/CPT_test.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename this to test_cpt.py
Thank you for the constructive feedback! 🙂
Sorry, I can’t find where I mentioned it. Could you please point me to it?
Yes, it is fully covered by this implementation, including both loss and projection. Full details are available in the demo (https://github.com/tsachiblau/CPT/tree/main/notebooks) and will also be added to this repository to make it easier to use.
Yes, we only support causal LMs. Adding
Let’s handle it later.
If the user choose RANDOM option then this values should be None. As for the remaining comments, I’ve addressed them all and pushed the changes to the branch. |
… created _cpt_forward for readability, updated copyright to 2024, renamed class to CPTPromptInit, changed config variables to lowercase and list[int], removed exception catch from tests, added assertion docs, removed batch_size=1 test, and renamed test file to test_cpt.py.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all the updates to the PR.
Something small I noticed:
Sorry, I can’t find where I mentioned it. Could you please point me to it?
This is in section 2 of the paper, 2nd paragraph starting with "Efficient Fine-Tuning".
Yes, it is fully covered by this implementation, including both loss and projection.
Okay, nice.
Yes, we only support causal LMs. Adding test_CPT to testing_common.py causes errors because the configuration is not initialized correctly, and I’m unsure how to address this. Could you clarify how to correctly include my tests?
So I think the following approach should work. See this line:
Line 202 in b3176ef
PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING) |
Let's create a new instance below called PeftTestConfigManagerForDecoderModels
. You instantiate it the same, but with the class mapping extended to add CPT:
PeftTestConfigManagerForDecoderModels = ClassInstantier({**CLASSES_MAPPING, **DECODER_MODELS_EXTRA})
Of course, we need to define DECODER_MODELS_EXTRA
, which should be:
DECODER_MODELS_EXTRA = {
"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[12])
}
You can add this after the definition of CLASSES_MAPPING
:
Line 124 in b3176ef
CLASSES_MAPPING = { |
Next, inside of test_decoder_models.py
, we have to use this new class:
- from .testing_common import PeftCommonTester, PeftTestConfigManager
+ from .testing_common import PeftCommonTester, PeftTestConfigManagerForDecoderModels as PeftTestConfigManager
Let me know if you have further questions.
As for the remaining comments, I’ve addressed them all and pushed the changes to the branch.
After all the changes have been made, make sure to also call make style
to make the linter happy.
There are still a couple of unaddressed comments, please check again.
src/peft/tuners/cpt/config.py
Outdated
) | ||
|
||
# Prompt tuning initialization method | ||
cpt_prompt_tuning_init: Optional[str] = field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be type annotated as CPTPromptInit
.
""" | ||
self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. | ||
self.target_modules = None # Placeholder for target modules in CPT. | ||
self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, since this and the next arguments should not be None, I think it makes more sense to remove the default=None and to remove the Optional type annotation. Then we don't need to check that in check_config.
If the user choose RANDOM option then this values should be None.
Okay, let's add a check here to ensure that the argument is correctly set.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment is still relevant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I mean is that a check should be performed here which checks what you mentioned in the quote above.
Thanks, I will check it out. I've made the other requested changes. :) |
…lization in config. Renamed cpt_prompt_tuning_init to cpt_prompt_init. Changed the class from PeftConfig to PromptLearningConfig. model: Removed check_config function. peft_model: Fixed bugs. tests: Added PeftTestConfigManagerForDecoderModels in test_decoder_models.py and testing_common.py.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates, this PR is approaching the finish line. Please ensure to always run make style
though, or else the linter will complain and tests won't run (I checked them locally and they passed).
Now, let's get back to what I mentioned earlier, namely adding some docs and an example. To give a recent example of a PR with docs and examples, check this one:
https://github.com/huggingface/peft/pull/2172/files
The example doesn't have to be this elaborate, but it should be something that users can easily adopt to their own uses cases. Maybe you can add something that resembles one of the experiments from the paper. That way, we can use the example to ensure that the experiment can be replicated with the PEFT implementation.
When writing the docs, put yourself in the shoes of a user who may not have read the paper and might be curious why they should consider this method.
tests/testing_common.py
Outdated
list_names.append(name) | ||
else: | ||
assert param.grad is None | ||
'' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
tests/testing_common.py
Outdated
@@ -1189,10 +1202,24 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar | |||
loss = output.sum() | |||
loss.backward() | |||
|
|||
if issubclass(config_cls, CPTConfig): | |||
parameters = [] | |||
list_names = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need list_names
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is redundant
tests/testing_common.py
Outdated
parameters = [] | ||
list_names = [] | ||
for name, param in model.prompt_encoder.named_parameters(): | ||
if name not in ['default.embedding.weight']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if name not in ['default.embedding.weight']: | |
if name != "default.embedding.weight": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
tests/test_cpt.py
Outdated
MODEL_NAME = "bigscience/bloom-1b7" | ||
MAX_INPUT_LENGTH = 1024 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For unit testing, we should use much smaller models. Check what the other tests are using. One possibility would be "hf-internal-testing/tiny-random-OPTForCausalLM"
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
tests/testing_common.py
Outdated
assert param.grad is not None | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
tests/test_cpt.py
Outdated
|
||
def test_model_initialization_text(global_tokenizer, config_text): | ||
"""Test model loading and PEFT model initialization.""" | ||
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=".", trust_remote_code=True) | |
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) |
Same below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
I added the documentation, except for an example, as I'm unsure where to place it. I noticed that some models have examples in the /examples/ directory, but I can't find a way to access these examples from https://huggingface.co/docs/peft/. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the recent updates. Please ensure to run make style
so that the linter is happy.
except for an example, as I'm unsure where to place it. I noticed that some models have examples in the /examples/ directory, but I can't find a way to access these examples from https://huggingface.co/docs/peft/.
Yes, examples should go into the examples/
directory. For this method, examples/causal_language_modeling/
could be a good option.
I'm not sure why you want to access the examples from the docs. The docs can link to the example, but I don't understand what else you would like to achieve there.
docs/source/package_reference/cpt.md
Outdated
|
||
[[autodoc]] tuners.cpt.config.CPTConfig | ||
|
||
## CPTModel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no CPTModel
, only CPTEmbedding
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can I add a link to an example in the cpt.md file? I didn't see any other methods linked to examples. If you have an example, it would be a great help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I see no reason not to add a link there. Just be aware that the link would be invalid until the PR is merged.
@tsachiblau Heads up, another PR was merged to PEFT which added a new method, resulting in a bunch of merge conflicts, but they should be easy to resolve. LMK if you have questions. |
Merge is done |
Anything else that needs to be done? |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@tsachiblau Could you please run |
Done |
@tsachiblau The linter is still complaining. Do you have the right ruff version installed? It should be v0.6.9. |
@tsachiblau ruff has now passed successfully but doc-builder is complaining, can you try running it?
|
Done |
The doc builder is still complaining, can you successfully run |
Yes, I get this message:
It seems to fail on the error handling that we implemented, such as
|
Ah yes, somehow I looked at an old log, now the linting indeed passes but the tests are failing. Currently I can't investigate why they're failing, but if the |
Done again :) |
@tsachiblau A couple of tests are failing. Mostly that concerns a test with gradient chechkpointing that checks the existence of a gradient on the prompt encoder. Could you check if this is a false alarm and the test needs adapting or if something else is going on? There is also another failing test during initialization. You can check the logs of the CI for more details. Thanks. |
Lets try again |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates. The tests are now passing (some flaky tests are failing, but we can ignore those for how). I did another check of the PR and found some smaller areas for improvement. Moreover, I saw that some of my previous comments are still unaddressed. If you disagree with a suggestion I made, just let me know, not everything needs to be changed, but if there is no reply I can't tell if you read it or not.
src/peft/peft_model.py
Outdated
@@ -1779,7 +1779,7 @@ def _cpt_forward( | |||
else: | |||
N_tokens = input_ids.shape[1] | |||
input_type_mask = torch.zeros((batch_size, N_tokens)).to(device) | |||
input_type_mask[:, -1] = 4 | |||
input_type_mask[:, :] = 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_type_mask.fill_(4)
would also work. Could you add a short comment on what "4" means here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4 is the id for the tokens used for the loss calculation. I changed the code to
input_type_mask = torch.ones((batch_size, N_tokens)).to(device) * 4
""" | ||
self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. | ||
self.target_modules = None # Placeholder for target modules in CPT. | ||
self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment is still relevant.
src/peft/tuners/cpt/config.py
Outdated
) | ||
|
||
# Prompt tuning initialization method | ||
cpt_prompt_init: Optional[str] = field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using Literal["TEXT", "RANDOM"]
as type annotation would be a bit more precise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment is still relevant.
It already exists in the code.
Using Literal["TEXT", "RANDOM"] as type annotation would be a bit more precise.
I changed it.
) | ||
|
||
# Loss-related configurations | ||
opt_weighted_loss_type: Optional[str] = field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still relevant
# Virtual token configurations | ||
num_virtual_tokens: int = field(default=0, metadata={"help": "Number of virtual tokens used in the prompt."}) | ||
|
||
# CPT-specific static attributes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT about this suggestion?
|
||
return epsilon | ||
|
||
def projection(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WDYT?
Can you please explain these points? I do not get what you suggest. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the latest updates.
Can you please explain these points? I do not get what you suggest.
I tried to clarify the open comments. I'm not sure if you can see the full context. If not, go to https://github.com/huggingface/peft/pull/2168/files and scroll down, you should see the full context of my comments.
) | ||
|
||
# Loss-related configurations | ||
opt_weighted_loss_type: Optional[str] = field( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My suggestion is to change the type annotation to Literal["none", "decay"]
.
) | ||
|
||
# Virtual token configurations | ||
num_virtual_tokens: int = field(default=0, metadata={"help": "Number of virtual tokens used in the prompt."}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think having 0 as the default here makes little sense. WDYT about using a good default here, say, 10?
# Virtual token configurations | ||
num_virtual_tokens: int = field(default=0, metadata={"help": "Number of virtual tokens used in the prompt."}) | ||
|
||
# CPT-specific static attributes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it ever make sense to let users pass these arguments? If not, I would remove them here and place them inside the __post_init__
method.
""" | ||
self.peft_type = PeftType.CPT # Specifies that the PEFT type is CPT. | ||
self.target_modules = None # Placeholder for target modules in CPT. | ||
self.task_type = "CAUSAL_LM" # Ensures task type is causal language modeling. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I mean is that a check should be performed here which checks what you mentioned in the quote above.
""" | ||
if self.config.CPT_prompt_tuning_init == PromptTuningInit.TEXT: | ||
tensor_ICL_mask = torch.Tensor(self.config.CPT_tokens_type_mask).long() | ||
mask_input_template = torch.remainder(tensor_ICL_mask, 4) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bumping this comment.
|
||
return epsilon | ||
|
||
def projection(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My suggestion is to rename this method to get_projection
. Then, in the last line, instead of self.delta_embedding.weight.data = new_embeddings_weights
, just return new_embeddings_weights
. It is then on the caller side that the delta_embeddings
are updated.
base_model_output (ModelOutput): Output from the base model containing logits. | ||
labels (torch.Tensor): Ground-truth labels for the input tokens. | ||
CPT_type_mask (torch.Tensor): Token type mask used for filtering valid loss terms. | ||
config (Namespace): Configuration object containing loss-related hyperparameters. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bump.
ModelOutput: The base model output with computed loss. | ||
""" | ||
|
||
if config.opt_weighted_loss_type in ["decay"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bump.
# Compute the weighted mean loss | ||
loss = (loss[shift_labels_bool] * shift_labels_weights[shift_labels_bool]).mean() | ||
base_model_output.loss = loss | ||
elif config.opt_weighted_loss_type not in ["none"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not if config.opt_weighted_loss_type == "none":
?
Hey,
This pull request introduces Context-Aware Prompt Tuning (CPT), a new and effective technique that builds on In-Context Learning (ICL) and Prompt Tuning (PT) with enhancements through adversarial optimization. CPT allows for better generalization and stability on various classification tasks.
The approach is based on a research paper, which will soon be available. The core idea of CPT is demonstrated and implemented in the following repository:
https://github.com/tsachiblau/CPT.
We are submitting this pull request to integrate the CPT method into the PEFT library, allowing users to experiment with this novel method. Thank you for reviewing this contribution!
The paper is attached
Context_aware_Prompt_Tuning__Advancing_In_Context_Learning_with_Adversarial_Methods_PEFT.pdf
Thanks,
Tsachi