16 changes: 16 additions & 0 deletions README.md
@@ -78,6 +78,22 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]
```

**Member:** IMO this README entry can be removed, as we generally don't highlight individual models here.

> [!NOTE]
> Transformer-based RWKV checkpoints (Transformers 4.56+) pick up LoRA defaults automatically. For a minimal
> adapter:

```python
from transformers import RwkvConfig, RwkvForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

config = RwkvConfig(hidden_size=512, num_hidden_layers=12, context_length=512)
model = RwkvForCausalLM(config)

# target_modules are filled in from PEFT's built-in mapping for the rwkv model type; no manual list required
lora_config = LoraConfig(r=8, lora_alpha=16, task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, lora_config)
```
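As a quick follow-up (not part of the diff above), the wrapped model can be inspected to confirm which submodules the defaults resolved to; a sketch using standard PEFT and PyTorch introspection:

```python
# Illustrative check, assuming `model` is the PEFT-wrapped RwkvForCausalLM from above.
model.print_trainable_parameters()  # reports the small fraction of trainable LoRA weights

# List the first few modules that were wrapped with LoRA adapters.
lora_layers = [name for name, module in model.named_modules() if hasattr(module, "lora_A")]
print(lora_layers[:4])
```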

## Why you should use PEFT

There are many benefits of using PEFT but the main one is the huge savings in compute and storage, making PEFT applicable to many different use cases.
2 changes: 2 additions & 0 deletions src/peft/utils/constants.py
@@ -100,6 +100,8 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
"qwen3": ["q_proj", "v_proj"],
"rwkv": ["key", "value", "receptance", "output"],
"rwkv7": ["r_proj", "k_proj", "v_proj", "o_proj", "key", "value"],
}

# target module mappings that are identical to LORA
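The two mapping entries added above are what allow `LoraConfig` to omit `target_modules` for RWKV models. For a checkpoint whose `model_type` is not (yet) covered by the mapping, the same effect can be obtained by listing the modules explicitly; a minimal sketch mirroring the `rwkv7` entry (the hyperparameters are arbitrary):

```python
from peft import LoraConfig, TaskType

# Explicit equivalent of the "rwkv7" default above; useful when a checkpoint's
# model_type is not in PEFT's built-in mapping. r and lora_alpha are arbitrary.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["r_proj", "k_proj", "v_proj", "o_proj", "key", "value"],
    task_type=TaskType.CAUSAL_LM,
)
```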
49 changes: 49 additions & 0 deletions tests/test_rwkv_lora.py
@@ -0,0 +1,49 @@
**Member:** There isn't really a need to add this test here, we don't have tests just to ensure that the default target modules are being set. I did, however, confirm that the test passes locally.

I could see an argument to add it to the general test suite, since RWKV has a different architecture:

```python
PEFT_DECODER_MODELS_TO_TEST = [
    "hf-internal-testing/tiny-random-OPTForCausalLM",
    "hf-internal-testing/tiny-random-GPT2LMHeadModel",
    "hf-internal-testing/tiny-random-BloomForCausalLM",
    "hf-internal-testing/tiny-random-gpt_neo",
    "hf-internal-testing/tiny-random-GPTJForCausalLM",
    "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM",
    "trl-internal-testing/tiny-random-LlamaForCausalLM",
    "peft-internal-testing/tiny-dummy-qwen2",
    "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
]
```

However, at the moment, the PEFT CI is already stressing the rate limit of HF Hub, so adding yet another model would not be a good idea. I think that if this situation relaxes and if we find that there is a big demand for RWKV finetuning with PEFT, we can consider that option.

# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch

from peft import LoraConfig, TaskType, get_peft_model


if not os.getenv("PEFT_RUN_RWKV_TESTS"):
    pytest.skip("RWKV tests are disabled by default; set PEFT_RUN_RWKV_TESTS=1 to enable.", allow_module_level=True)

transformers = pytest.importorskip("transformers")


@pytest.mark.parametrize("seq_len", [4])
def test_rwkv_lora_forward_backward(seq_len: int):
    # Tiny RWKV config so the test stays fast and memory-light.
    config = transformers.RwkvConfig(
        hidden_size=32,
        attention_hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        vocab_size=64,
        context_length=seq_len,
    )
    model = transformers.RwkvForCausalLM(config)

    # No target_modules given: they should be picked up from PEFT's default mapping for RWKV.
    lora_config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.0, task_type=TaskType.CAUSAL_LM)
    model = get_peft_model(model, lora_config)

    input_ids = torch.randint(0, config.vocab_size, (2, seq_len))
    output = model(input_ids=input_ids)
    loss = output.logits.float().mean()
    loss.backward()

    # Every trainable LoRA parameter should have received a gradient.
    grads = [param.grad for name, param in model.named_parameters() if "lora_" in name and param.requires_grad]
    assert grads and all(g is not None for g in grads)
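
Beyond the gradient check above, an illustrative next step (not part of the test file) would be a single optimizer update on the adapted model. This sketch assumes the `model` and `config` from the test, and that `RwkvForCausalLM` computes a language-modeling loss when `labels` are passed, as other transformers causal LM classes do:

```python
import torch

# Sketch of one LoRA fine-tuning step on the adapted RWKV model from the test above.
# Assumes `model` and `config` are defined as in test_rwkv_lora_forward_backward;
# the learning rate and batch shape are arbitrary.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)

input_ids = torch.randint(0, config.vocab_size, (2, 4))
outputs = model(input_ids=input_ids, labels=input_ids)  # labels trigger the built-in LM loss
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```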