From 378f30f0b6840480b12d9e19fea2c0b48012476b Mon Sep 17 00:00:00 2001
From: Nir Ben-Or
Date: Wed, 1 Oct 2025 22:21:46 -0500
Subject: [PATCH] Add RWKV LoRA defaults and opt-in test

---
 README.md                   | 16 ++++++++++++
 src/peft/utils/constants.py |  2 ++
 tests/test_rwkv_lora.py     | 49 +++++++++++++++++++++++++++++++++++++
 3 files changed, 67 insertions(+)
 create mode 100644 tests/test_rwkv_lora.py

diff --git a/README.md b/README.md
index 77a61c68a3..887ca2b75e 100644
--- a/README.md
+++ b/README.md
@@ -78,6 +78,22 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 # prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]
 ```
 
+> [!NOTE]
+> RWKV checkpoints loaded through 🤗 Transformers (v4.56+) pick up LoRA `target_modules` defaults
+> automatically. To attach a minimal adapter:
+
+```python
+from transformers import RwkvConfig, RwkvForCausalLM
+from peft import LoraConfig, TaskType, get_peft_model
+
+config = RwkvConfig(hidden_size=512, num_hidden_layers=12, context_length=512)
+model = RwkvForCausalLM(config)
+
+# target_modules default to PEFT's built-in mapping for RWKV; no manual list required
+lora_config = LoraConfig(r=8, lora_alpha=16, task_type=TaskType.CAUSAL_LM)
+model = get_peft_model(model, lora_config)
+```
+
 ## Why you should use PEFT
 
 There are many benefits of using PEFT but the main one is the huge savings in compute and storage, making PEFT applicable to many different use cases.
diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py
index c9d88df5a6..188c2d5de6 100644
--- a/src/peft/utils/constants.py
+++ b/src/peft/utils/constants.py
@@ -100,6 +100,8 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "gemma3_text": ["q_proj", "v_proj"],
     "qwen2": ["q_proj", "v_proj"],
     "qwen3": ["q_proj", "v_proj"],
+    "rwkv": ["key", "value", "receptance", "output"],
+    "rwkv7": ["r_proj", "k_proj", "v_proj", "o_proj", "key", "value"],
 }
 
 # target module mappings that are identical to LORA
diff --git a/tests/test_rwkv_lora.py b/tests/test_rwkv_lora.py
new file mode 100644
index 0000000000..fb86eafee9
--- /dev/null
+++ b/tests/test_rwkv_lora.py
@@ -0,0 +1,49 @@
+# Copyright 2025-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+
+from peft import LoraConfig, TaskType, get_peft_model
+
+
+if not os.getenv("PEFT_RUN_RWKV_TESTS"):
+    pytest.skip("RWKV tests are disabled by default; set PEFT_RUN_RWKV_TESTS=1 to enable.", allow_module_level=True)
+
+transformers = pytest.importorskip("transformers")
+
+
+@pytest.mark.parametrize("seq_len", [4])
+def test_rwkv_lora_forward_backward(seq_len: int):
+    config = transformers.RwkvConfig(
+        hidden_size=32,
+        attention_hidden_size=32,
+        intermediate_size=64,
+        num_hidden_layers=2,
+        vocab_size=64,
+        context_length=seq_len,
+    )
+    model = transformers.RwkvForCausalLM(config)
+
+    lora_config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.0, task_type=TaskType.CAUSAL_LM)
+    model = get_peft_model(model, lora_config)
+
+    input_ids = torch.randint(0, config.vocab_size, (2, seq_len))
+    output = model(input_ids=input_ids)
+    loss = output.logits.float().mean()
+    loss.backward()
+
+    grads = [param.grad for name, param in model.named_parameters() if "lora_" in name and param.requires_grad]
+    assert grads and all(g is not None for g in grads)
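
Review note: the test is opt-in; run it with `PEFT_RUN_RWKV_TESTS=1 pytest tests/test_rwkv_lora.py`. To sanity-check locally that the new `rwkv` default mapping wraps the intended projections, a sketch along these lines should work once the patch is applied. It assumes the usual PEFT parameter layout (`...<module>.lora_A.default.weight`) and the module names of the Transformers `rwkv` implementation; the exact names may differ across versions:

```python
from transformers import RwkvConfig, RwkvForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Tiny model so the check runs in seconds on CPU.
config = RwkvConfig(hidden_size=32, attention_hidden_size=32, num_hidden_layers=2, vocab_size=64, context_length=8)
model = get_peft_model(RwkvForCausalLM(config), LoraConfig(r=4, lora_alpha=16, task_type=TaskType.CAUSAL_LM))

# Wrapped layers expose parameters like "...attention.key.lora_A.default.weight",
# so the wrapped module's name sits four segments from the end.
wrapped = sorted({name.split(".")[-4] for name, _ in model.named_parameters() if ".lora_A." in name})
print(wrapped)  # expected: ['key', 'output', 'receptance', 'value']
model.print_trainable_parameters()
```

Note that matching is by module name, so the `key`/`value`/`receptance` linears in both the attention and feed-forward blocks are wrapped; the set of distinct names above should still be the four entries from the new mapping.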