
Commit 2bfe54b

BenjaminBossan authored and BernardZach committed
[PEFT] Add warning for missing key in LoRA adapter (huggingface#34068)

When loading a LoRA adapter, so far there was only a warning when the checkpoint contained unexpected keys. Now there is also a warning when keys are missing. This change is consistent with huggingface/peft#2118 in PEFT and the planned PR huggingface/diffusers#9622 in diffusers.

Apart from this change, the error message for unexpected keys was slightly altered for consistency (it should be more readable now). Also, besides adding a test for the missing keys warning, a test for the unexpected keys warning was added, as it was missing so far.
1 parent 0aad52e commit 2bfe54b
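To make the new behavior concrete, here is a minimal sketch of how the missing-keys warning would surface for a user. The base model and adapter IDs below are placeholders chosen for illustration and are not part of this commit:

```python
import torch
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from transformers import AutoModelForCausalLM

# Placeholder IDs; substitute any causal LM with a matching LoRA adapter on the Hub.
base_model_id = "facebook/opt-350m"
adapter_id = "some-user/opt-350m-lora"

model = AutoModelForCausalLM.from_pretrained(base_model_id)

# Download the adapter weights and drop one LoRA key to simulate an incomplete checkpoint.
state_dict_path = hf_hub_download(adapter_id, "adapter_model.bin")
state_dict = torch.load(state_dict_path)
state_dict.pop(next(iter(state_dict)))

# With this commit, load_adapter logs a warning along the lines of:
#   "Loading adapter weights from state_dict led to missing keys in the model: ..."
model.load_adapter(adapter_state_dict=state_dict, peft_config=LoraConfig())
```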

File tree

2 files changed: +96 −6 lines changed


src/transformers/integrations/peft.py (+20 −4)

@@ -235,13 +235,29 @@ def load_adapter(
         )
 
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            err_msg = ""
+            origin_name = peft_model_id if peft_model_id is not None else "state_dict"
+            # Check for unexpected keys.
             if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
-                logger.warning(
-                    f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
-                    f" {incompatible_keys.unexpected_keys}. "
+                err_msg = (
+                    f"Loading adapter weights from {origin_name} led to unexpected keys not found in the model: "
+                    f"{', '.join(incompatible_keys.unexpected_keys)}. "
                 )
 
+            # Check for missing keys.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                # Filter missing keys specific to the current adapter, as missing base model keys are expected.
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    err_msg += (
+                        f"Loading adapter weights from {origin_name} led to missing keys in the model: "
+                        f"{', '.join(lora_missing_keys)}"
+                    )
+
+            if err_msg:
+                logger.warning(err_msg)
+
         # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
         if (
             (getattr(self, "hf_device_map", None) is not None)
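As a quick illustration of the adapter-specific filter in the hunk above, only missing keys that belong to LoRA layers of the adapter being loaded end up in the warning; missing base-model keys are expected and ignored. The key names below are invented for the example, not taken from a real checkpoint:

```python
# Invented example keys, shaped like what PEFT reports for an OPT model.
missing_keys = [
    "model.decoder.layers.0.self_attn.q_proj.base_layer.weight",           # base weight -> ignored
    "model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight",       # this adapter -> reported
    "model.decoder.layers.0.self_attn.q_proj.lora_A.other_adapter.weight", # other adapter -> ignored
]
adapter_name = "default"

# Same filter as in load_adapter above.
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
print(lora_missing_keys)
# ['model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight']
```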

tests/peft_integration/test_peft_integration.py (+76 −2)

@@ -20,8 +20,9 @@
 from huggingface_hub import hf_hub_download
 from packaging import version
 
-from transformers import AutoModelForCausalLM, OPTForCausalLM
+from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
 from transformers.testing_utils import (
+    CaptureLogger,
     require_bitsandbytes,
     require_peft,
     require_torch,
@@ -72,9 +73,15 @@ def test_peft_from_pretrained(self):
         This checks if we pass a remote folder that contains an adapter config and adapter weights, it
         should correctly load a model that has adapters injected on it.
         """
+        logger = logging.get_logger("transformers.integrations.peft")
+
         for model_id in self.peft_test_model_ids:
             for transformers_class in self.transformers_test_model_classes:
-                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with CaptureLogger(logger) as cl:
+                    peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                # ensure that under normal circumstances, there are no warnings about keys
+                self.assertNotIn("unexpected keys", cl.out)
+                self.assertNotIn("missing keys", cl.out)
 
                 self.assertTrue(self._check_lora_correctly_converted(peft_model))
                 self.assertTrue(peft_model._hf_peft_config_loaded)
@@ -548,3 +555,70 @@ def test_peft_from_pretrained_hub_kwargs(self):
 
         model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
         self.assertTrue(self._check_lora_correctly_converted(model))
+
+    def test_peft_from_pretrained_unexpected_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with unexpected keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # add unexpected key
+                dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                    )
+
+                msg = "Loading adapter weights from state_dict led to unexpected keys not found in the model: foobar"
+                self.assertIn(msg, cl.out)
+
+    def test_peft_from_pretrained_missing_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with missing keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # remove a key so that we have missing keys
+                key = next(iter(dummy_state_dict.keys()))
+                del dummy_state_dict[key]
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=False,
+                        adapter_name="other",
+                    )
+
+                # Here we need to adjust the key name a bit to account for PEFT-specific naming.
+                # 1. Remove PEFT-specific prefix
+                # If merged after dropping Python 3.8, we can use: key = key.removeprefix(peft_prefix)
+                peft_prefix = "base_model.model."
+                key = key[len(peft_prefix) :]
+                # 2. Insert adapter name
+                prefix, _, suffix = key.rpartition(".")
+                key = f"{prefix}.other.{suffix}"
+
+                msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
+                self.assertIn(msg, cl.out)
