68 changes: 68 additions & 0 deletions src/peft/peft_model.py
@@ -1634,6 +1634,74 @@ def supports_lora_conversion(self, adapter_name: str = "default") -> bool:

        return self.base_model.supports_lora_conversion()

    def get_base_model_state_dict(self) -> dict[str, torch.Tensor]:
        """
        Returns the state dict of the base model with the original model keys.

        This method extracts the base model's parameters, removing PEFT-specific key modifications and filtering out
        adapter-specific parameters (like LoRA matrices).

        This is useful when you need to access or save the base model's weights with their original key names.

        Returns:
            `dict[str, torch.Tensor]`:
                The base model's state dict with original keys (without PEFT modifications).

        Example:
        ```python
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import get_peft_model, LoraConfig

        >>> base_model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> original_keys = set(base_model.state_dict().keys())

        >>> peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))
        >>> base_state_dict = peft_model.get_base_model_state_dict()

        >>> # The keys match the original model
        >>> assert set(base_state_dict.keys()) == original_keys
        ```
        """
        # For prompt learning methods, the base model structure is not modified
        if self._is_prompt_learning:
            return dict(self.base_model.state_dict())

        # Get state dict from the underlying model
        state_dict = self.base_model.model.state_dict()

        # Collect all adapter prefixes to identify adapter-specific parameters
        adapter_prefixes: set[str] = set()
        for config in self.peft_config.values():
            prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(config.peft_type)
            if prefix:
                adapter_prefixes.add(prefix)

        result: dict[str, torch.Tensor] = {}

        for key, value in state_dict.items():
            # Skip adapter-specific params such as .lora_A, .lora_B
            is_adapter_param = False
            for prefix in adapter_prefixes:
                if f".{prefix}" in key or key.startswith(prefix):
                    is_adapter_param = True
                    break

            if is_adapter_param:
                continue

            # Skip adapter-specific copies in modules_to_save
            if ".modules_to_save." in key:
                continue

            # Transform keys to original format by removing PEFT-specific infixes,
            # e.g. "...q_proj.base_layer.weight" -> "...q_proj.weight"
            new_key = key
            new_key = new_key.replace(".base_layer.", ".")  # for tuner layers
            new_key = new_key.replace(".original_module.", ".")  # for modules_to_save

            result[new_key] = value

        return result


class PeftModelForSequenceClassification(PeftModel):
"""
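As a complement to the docstring example above, here is a minimal usage sketch of the new method. It assumes `get_base_model_state_dict` lands as shown in this diff; the model id `gpt2`, the `c_attn` target module, and the file name `base_weights.pt` are illustrative choices, not part of the PR.

```python
# Usage sketch only: assumes the get_base_model_state_dict method added in this diff.
import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["c_attn"]))

# ... fine-tune the LoRA adapter here ...

# Extract the base weights under their original key names (no ".base_layer." infix,
# no lora_A/lora_B entries) and save them independently of the adapter.
base_state_dict = peft_model.get_base_model_state_dict()
torch.save(base_state_dict, "base_weights.pt")

# The saved file can be loaded back into a plain transformers model.
fresh_model = AutoModelForCausalLM.from_pretrained("gpt2")
fresh_model.load_state_dict(torch.load("base_weights.pt"))
```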
114 changes: 114 additions & 0 deletions tests/test_get_base_model_state_dict.py
@@ -0,0 +1,114 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, PromptTuningConfig, TaskType, get_peft_model
from peft.utils import infer_device

from .testing_utils import hub_online_once


def test_get_base_model_state_dict_matches():
    # Test to check whether all the keys in the base model match the keys
    # of the LoRA-wrapped model when calling the get_base_model_state_dict method
    model_id = "peft-internal-testing/tiny-random-OPTForCausalLM"
    torch_device = infer_device()
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
        base_model_keys = set(base_model.state_dict().keys())
        lora_config = LoraConfig(r=4, lora_alpha=2, target_modules="all-linear", lora_dropout=0.1)
        peft_model = get_peft_model(base_model, lora_config)
        new_state_dict = set(peft_model.get_base_model_state_dict().keys())
        assert base_model_keys == new_state_dict


def test_get_base_model_state_dict_values_match():
    # Test that the actual tensor values match the original base model weights
    model_id = "peft-internal-testing/tiny-random-OPTForCausalLM"
    torch_device = infer_device()
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
        original_state_dict = {k: v.clone() for k, v in base_model.state_dict().items()}

        lora_config = LoraConfig(r=4, lora_alpha=2, target_modules="all-linear")
        peft_model = get_peft_model(base_model, lora_config)

        extracted_state_dict = peft_model.get_base_model_state_dict()

        for key in original_state_dict:
            assert key in extracted_state_dict
            assert torch.allclose(original_state_dict[key], extracted_state_dict[key])


def test_get_base_model_state_dict_with_modules_to_save():
    # Test that modules_to_save are handled correctly (filters .modules_to_save.
    # keys and transforms .original_module. keys back to original format)
    model_id = "peft-internal-testing/tiny-random-OPTForCausalLM"
    torch_device = infer_device()
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)

        base_model_keys = set(base_model.state_dict().keys())

        lora_config = LoraConfig(
            r=4,
            lora_alpha=2,
            target_modules="all-linear",
            modules_to_save=["lm_head"],
        )
        peft_model = get_peft_model(base_model, lora_config)

        extracted_keys = set(peft_model.get_base_model_state_dict().keys())
        assert base_model_keys == extracted_keys


def test_get_base_model_state_dict_with_multiple_adapters():
    # Test that base model state dict is correctly extracted when multiple
    # adapters are present, ensuring all adapter params are filtered out
    model_id = "peft-internal-testing/tiny-random-OPTForCausalLM"
    torch_device = infer_device()
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)

        base_model_keys = set(base_model.state_dict().keys())

        lora_config_1 = LoraConfig(r=4, lora_alpha=2, target_modules=["q_proj", "v_proj"])
        peft_model = get_peft_model(base_model, lora_config_1, adapter_name="adapter1")

        lora_config_2 = LoraConfig(r=8, lora_alpha=4, target_modules=["k_proj", "out_proj"])
        peft_model.add_adapter("adapter2", lora_config_2)

        extracted_keys = set(peft_model.get_base_model_state_dict().keys())
        assert base_model_keys == extracted_keys


def test_get_base_model_state_dict_prompt_learning():
    # Test with prompt learning method
    model_id = "peft-internal-testing/tiny-random-OPTForCausalLM"
    torch_device = infer_device()
    with hub_online_once(model_id):
        base_model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)

        base_model_keys = set(base_model.state_dict().keys())

        prompt_config = PromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            num_virtual_tokens=10,
        )
        peft_model = get_peft_model(base_model, prompt_config)

        extracted_keys = set(peft_model.get_base_model_state_dict().keys())
        assert base_model_keys == extracted_keys