Skip to content

Commit ed3c828

Browse files
FIX: Avoid needless copy from modules_to_save (#2220)
Resolves #2206 The problem is that we keep a "global" modules_to_save on the model which contains all possible modules_to_save for each adapter. When the first adapter targets layer "foo" with modules_to_save and the second adapter targets "bar", then "foo" will create a copy of the original module for the second adapter, even though it's not needed. This does not change the result but is unnecessary and takes up memory. Thus it should be avoided.
1 parent 9c11a3e commit ed3c828

File tree

4 files changed

+32
-9
lines changed

4 files changed

+32
-9
lines changed

src/peft/mixed_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> Non
251251
self.modules_to_save = set(modules_to_save)
252252
else:
253253
self.modules_to_save.update(modules_to_save)
254-
_set_trainable(self, adapter_name)
254+
_set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
255255

256256
def set_adapter(self, adapter_name: Union[str, list[str]]) -> None:
257257
"""

src/peft/peft_model.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,8 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
949949
self.modules_to_save = set(peft_config.modules_to_save)
950950
else:
951951
self.modules_to_save.update(peft_config.modules_to_save)
952-
_set_trainable(self, adapter_name) # this may add a new ModulesToSaveWrapper
952+
# this may add a new ModulesToSaveWrapper
953+
_set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
953954

954955
def get_layer_status(self) -> list[TunerLayerStatus]:
955956
"""Get the status of each adapter layer in the model.
@@ -1446,7 +1447,7 @@ def __init__(
14461447
break
14471448

14481449
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
1449-
_set_trainable(self, adapter_name)
1450+
_set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
14501451

14511452
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
14521453
"""
@@ -2237,7 +2238,7 @@ def __init__(
22372238
break
22382239

22392240
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
2240-
_set_trainable(self, adapter_name)
2241+
_set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
22412242

22422243
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
22432244
"""
@@ -2458,7 +2459,7 @@ def __init__(
24582459
break
24592460

24602461
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
2461-
_set_trainable(self, adapter_name)
2462+
_set_trainable(self, adapter_name, modules_to_save=peft_config.modules_to_save)
24622463

24632464
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
24642465
"""

src/peft/utils/other.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,10 @@ def update(self, adapter_name):
275275

276276
context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
277277
break
278-
with context_manager:
279-
self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
278+
279+
if adapter_name not in self.modules_to_save:
280+
with context_manager:
281+
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
280282

281283
if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
282284
old_hook = self.modules_to_save[adapter_name]._hf_hook
@@ -416,10 +418,10 @@ def _freeze_adapter(model, adapter_name):
416418
p.requires_grad = False
417419

418420

419-
def _set_trainable(model, adapter_name):
421+
def _set_trainable(model, adapter_name, modules_to_save):
420422
key_list = [key for key, _ in model.named_modules()]
421423
for key in key_list:
422-
target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
424+
target_module_found = any(key.endswith(target_key) for target_key in modules_to_save)
423425
if target_module_found:
424426
parent, target, target_name = _get_submodules(model, key)
425427
if isinstance(target, ModulesToSaveWrapper):

tests/test_custom_models.py

+20
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,26 @@ def test_multiple_adapters_seq_cls_mixed_modules_to_save_merging_adapters(self):
16081608
with pytest.raises(ValueError, match=msg):
16091609
model.add_weighted_adapter(["default", "other"], weights=[1.0, 1.0], adapter_name="merged")
16101610

1611+
def test_multiple_adapters_no_needless_copy_modules_to_save(self):
1612+
# See 2206
1613+
# The problem was that we keep a "global" modules_to_save on the model which contains all possible
1614+
# modules_to_save for each adapter. When the first adapter targets embed_tokens with modules_to_save and the
1615+
# second adapter targets lm_head, then embed_tokens will create a copy of the original module for the second
1616+
# adapter, even though it's not needed. The copy still acts as expected but uses unnecessary memory.
1617+
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
1618+
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
1619+
config0 = LoraConfig(modules_to_save=["embed_tokens"])
1620+
config1 = LoraConfig(modules_to_save=["lm_head"])
1621+
model = get_peft_model(model, config0)
1622+
model.add_adapter("other", config1)
1623+
1624+
lm_head_keys = list(model.base_model.model.lm_head.modules_to_save.keys())
1625+
assert lm_head_keys == ["other"]
1626+
1627+
embed_token_keys = list(model.base_model.model.model.decoder.embed_tokens.modules_to_save.keys())
1628+
# before the fix, this would be: ['default', 'other']
1629+
assert embed_token_keys == ["default"]
1630+
16111631
def test_existing_model_card(self):
16121632
# ensure that if there is already a model card, it is not overwritten
16131633
model = MLP()

0 commit comments

Comments
 (0)