@@ -20,8 +20,9 @@
 from huggingface_hub import hf_hub_download
 from packaging import version
 
-from transformers import AutoModelForCausalLM, OPTForCausalLM
+from transformers import AutoModelForCausalLM, OPTForCausalLM, logging
 from transformers.testing_utils import (
+    CaptureLogger,
     require_bitsandbytes,
     require_peft,
     require_torch,
@@ -72,9 +73,15 @@ def test_peft_from_pretrained(self):
         This checks if we pass a remote folder that contains an adapter config and adapter weights, it
         should correctly load a model that has adapters injected on it.
         """
+        logger = logging.get_logger("transformers.integrations.peft")
+
         for model_id in self.peft_test_model_ids:
             for transformers_class in self.transformers_test_model_classes:
-                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with CaptureLogger(logger) as cl:
+                    peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+                # ensure that under normal circumstances, there are no warnings about keys
+                self.assertNotIn("unexpected keys", cl.out)
+                self.assertNotIn("missing keys", cl.out)
 
                 self.assertTrue(self._check_lora_correctly_converted(peft_model))
                 self.assertTrue(peft_model._hf_peft_config_loaded)
@@ -548,3 +555,70 @@ def test_peft_from_pretrained_hub_kwargs(self):
 
         model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
         self.assertTrue(self._check_lora_correctly_converted(model))
+
+    def test_peft_from_pretrained_unexpected_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with unexpected keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # add unexpected key
+                dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                    )
+
+                msg = "Loading adapter weights from state_dict led to unexpected keys not found in the model: foobar"
+                self.assertIn(msg, cl.out)
+
+    def test_peft_from_pretrained_missing_keys_warning(self):
+        """
+        Test for warning when loading a PEFT checkpoint with missing keys.
+        """
+        from peft import LoraConfig
+
+        logger = logging.get_logger("transformers.integrations.peft")
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # remove a key so that we have missing keys
+                key = next(iter(dummy_state_dict.keys()))
+                del dummy_state_dict[key]
+
+                with CaptureLogger(logger) as cl:
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=False,
+                        adapter_name="other",
+                    )
+
+                # Here we need to adjust the key name a bit to account for PEFT-specific naming.
+                # 1. Remove the PEFT-specific prefix.
+                #    Once Python 3.8 support is dropped, this can become: key = key.removeprefix(peft_prefix)
+                peft_prefix = "base_model.model."
+                key = key[len(peft_prefix) :]
+                # 2. Insert the adapter name
+                prefix, _, suffix = key.rpartition(".")
+                key = f"{prefix}.other.{suffix}"
+
+                msg = f"Loading adapter weights from state_dict led to missing keys in the model: {key}"
+                self.assertIn(msg, cl.out)