Add RWKV LoRA defaults and opt-in test #2810
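The "LoRA defaults" in the title refer to default target_modules for RWKV, so that LoraConfig works on RWKV models without specifying target_modules by hand. For context, here is a minimal sketch of what such an entry could look like in PEFT's TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING; the module names are an assumption based on the linear layers (key, value, receptance) in transformers' RWKV blocks, not taken from this PR's diff, which only shows the test file:

    # Hedged sketch, not the PR's actual diff: PEFT keeps per-architecture
    # LoRA defaults in a mapping, and an RWKV entry could plausibly look like
    # this. "key", "value", "receptance" are assumed names taken from the
    # linear layers in transformers' RWKV attention/feed-forward blocks.
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
        # ... existing entries ("llama", "gpt2", ...) ...
        "rwkv": ["key", "value", "receptance"],
    }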
Review comment (maintainer, on the new test file):

There isn't really a need to add this test here; we don't have tests just to ensure that the default target modules are being set. I did, however, confirm that the test passes locally. I could see an argument for adding it to the general test suite, since RWKV has a different architecture (see peft/tests/test_decoder_models.py, lines 56 to 66 at 815956b).

However, at the moment the PEFT CI is already stressing the rate limit of the HF Hub, so adding yet another model would not be a good idea. I think that if this situation relaxes and we find that there is a big demand for RWKV finetuning with PEFT, we can consider that option.
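The referenced lines are not reproduced in this view. As a rough illustration only, assuming test_decoder_models.py is driven by a parametrized list of small Hub checkpoints (an assumption about that file's structure), opting RWKV into the general suite would amount to something like:

    # Rough illustration, not the actual contents of lines 56 to 66: assuming
    # the decoder test suite parametrizes over small Hub checkpoints, opting
    # RWKV in would mean adding a checkpoint id such as the one below (the
    # specific model choice here is an assumption).
    PEFT_DECODER_MODELS_TO_TEST = [
        # ... existing entries ...
        "RWKV/rwkv-4-169m-pile",
    ]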
The diff adds a single new file, the opt-in RWKV test module:

@@ -0,0 +1,49 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

from peft import LoraConfig, TaskType, get_peft_model


# Opt-in guard: the whole module is skipped unless PEFT_RUN_RWKV_TESTS is set.
if not os.getenv("PEFT_RUN_RWKV_TESTS"):
    pytest.skip("RWKV tests are disabled by default; set PEFT_RUN_RWKV_TESTS=1 to enable.", allow_module_level=True)

transformers = pytest.importorskip("transformers")


@pytest.mark.parametrize("seq_len", [4])
def test_rwkv_lora_forward_backward(seq_len: int):
    # Tiny randomly initialized RWKV model so the test stays fast and offline.
    config = transformers.RwkvConfig(
        hidden_size=32,
        attention_hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        vocab_size=64,
        context_length=seq_len,
    )
    model = transformers.RwkvForCausalLM(config)

    # No target_modules given: this exercises the RWKV defaults added by the PR.
    lora_config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.0, task_type=TaskType.CAUSAL_LM)
    model = get_peft_model(model, lora_config)

    input_ids = torch.randint(0, config.vocab_size, (2, seq_len))
    output = model(input_ids=input_ids)
    loss = output.logits.float().mean()
    loss.backward()

    # Every trainable LoRA parameter should have received a gradient.
    grads = [param.grad for name, param in model.named_parameters() if "lora_" in name and param.requires_grad]
    assert grads and all(g is not None for g in grads)
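Since the module-level guard skips everything by default, the test has to be opted into explicitly. Assuming the file lands at tests/test_rwkv_lora.py (the exact path is not visible in this view), a local run would look like:

    PEFT_RUN_RWKV_TESTS=1 pytest tests/test_rwkv_lora.py -v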
Review comment (on the README change in this PR):

IMO this README entry can be removed, as we generally don't highlight individual models here.