Ensure weight tying is maintained for embed_tokens and lm_head #2803
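For context, a minimal illustration (not part of this diff) of the problem the PR addresses: in a typical causal LM, `embed_tokens` and `lm_head` share one weight matrix, and wrapping `embed_tokens` as a `modules_to_save` entry deep-copies it, which silently breaks that tie.

```python
import copy
import torch.nn as nn

embed_tokens = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed_tokens.weight          # tied, as in a typical causal LM
assert lm_head.weight is embed_tokens.weight

trainable_copy = copy.deepcopy(embed_tokens)  # what ModulesToSaveWrapper does today
assert trainable_copy.weight is not lm_head.weight  # the tie is silently broken
```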
@@ -690,6 +690,36 @@ def inject_adapter(
         # in a bad (half-initialized) state.
         self._check_new_adapter_config(peft_config)
 
+        modules_to_save = (
+            set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
+        )
Suggested change:

-            set(getattr(peft_config, "modules_to_save", [])) if getattr(peft_config, "modules_to_save", []) else set()
+            set(getattr(peft_config, "modules_to_save", []))

Review discussion: `modules_to_save` is set as `None` in src/peft/mixed_model.py, which results in an error.
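A standalone sketch of the failure mode being discussed: because `modules_to_save` can default to `None`, calling `set()` on it directly raises, hence the falsy check before building the set.

```python
modules_to_save = None  # e.g. a config where modules_to_save was left unset

# unguarded: set(modules_to_save) raises TypeError: 'NoneType' object is not iterable

# guarded, as in the diff above
modules_to_save = set(modules_to_save) if modules_to_save else set()
print(modules_to_save)  # set()
```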
@@ -508,10 +508,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
     # All names of layers that may contain adapter (trainable) weights
     adapter_layer_names: tuple[str, ...] = ("modules_to_save",)
 
-    def __init__(self, module_to_save, adapter_name):
-        super().__init__(module_to_save, adapter_name)
+    def __init__(self, module_to_save, adapter_name, tied_module=None):
+        super().__init__(module_to_save, adapter_name, tied_module=tied_module)
 
-    def init_modules(self, adapter_name):
+    def init_modules(self, adapter_name, **kwargs):
         # we treat each adapter separately, so we have multiple adapters, same (copied) module for each
         self.modules_to_save = torch.nn.ModuleDict({})
@@ -548,9 +548,17 @@ def update(self, adapter_name, **kwargs):
                 context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
                 break
 
+        tied_module = kwargs.get("tied_module", None)
+
         if adapter_name not in self.modules_to_save:
             with context_manager:
-                self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
+                if tied_module:
+                    new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
+                    new_linear.weight = tied_module.weight
+
+                    self.modules_to_save[adapter_name] = new_linear
+                else:
+                    self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
 
         if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
             old_hook = self.modules_to_save[adapter_name]._hf_hook
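The `new_linear.weight = tied_module.weight` assignment above makes the replacement `lm_head` reuse the very `nn.Parameter` held by the trainable `embed_tokens` copy, so both modules train a single tensor. A standalone sketch of that mechanism (module names and sizes are illustrative):

```python
import torch
import torch.nn as nn

# stand-ins: the trainable embed_tokens copy and the lm_head replacement built from it
tied_module = nn.Embedding(10, 4)
new_linear = nn.Linear(4, 10, bias=False)
new_linear.weight = tied_module.weight      # share the Parameter instead of deep-copying

assert new_linear.weight is tied_module.weight
logits = new_linear(torch.randn(2, 4))      # shape (2, 10): hidden states projected back to the vocabulary
logits.sum().backward()
assert tied_module.weight.grad is not None  # gradients accumulate on the single shared tensor
```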
@@ -1402,6 +1410,17 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
             activate_adapter=activate_adapter,
         )
 
+    if getattr(peft_config, "modules_to_tie", None) is not None:
+        tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+        _set_trainable(
+            model,
+            adapter_name,
+            inference_mode=peft_config.inference_mode,
+            module_names=getattr(peft_config, "modules_to_tie", None),
+            activate_adapter=activate_adapter,
+            tied_module=tied_module,
+        )
+
     if getattr(peft_config, "trainable_token_indices", None) is not None:
         if isinstance(peft_config.trainable_token_indices, dict):
             target_layers = peft_config.trainable_token_indices
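From the user's side, this behaviour is driven by the new `ensure_weight_tying` flag together with `modules_to_save`, as exercised by the tests below. A rough usage sketch, where the base model and the `target_modules` names are placeholders:

```python
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    target_modules=["q_proj", "v_proj"],   # placeholder target modules
    modules_to_save=["embed_tokens"],      # embed_tokens gets a trainable copy
    ensure_weight_tying=True,              # keep lm_head tied to that copy
)
peft_model = get_peft_model(base_model, config)  # base_model: a causal LM with tied embeddings
```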
@@ -110,6 +110,47 @@ def forward(self, x):
 
         return MyModule().eval().to(self.torch_device)
 
+    def get_lm_model(self, bias=True, tie_weights=True):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.linear = nn.Linear(1000, 1000, bias=bias)
+                self.embed_tokens = nn.Embedding(1000, 1000)
+                self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)
+
+            def forward(self, x):
+                x_int = (x * 100).int()
+                x_4d = x.reshape(1, 100, 10, 10)
+
+                return self.linear(x), self.embed_tokens(x_int), self.conv2d(x_4d)
+
+        class CausalLM(nn.Module):
+            if tie_weights:
+                _tied_weights_keys = ["lm_head.weight"]
+
+            def __init__(self):
+                super().__init__()
+                self.model = MyModule()
+                self.config = {"tie_word_embeddings": tie_weights}
+
+                if tie_weights:
+                    self.lm_head = nn.Linear(1000, 1000, bias=False)
+                    self.lm_head.weight = self.model.embed_tokens.weight
+                else:
+                    self.lm_head = nn.Linear(1000, 1000, bias=bias)
+
+            def forward(self, x):
+                return self.model(x)
+
+            def prepare_inputs_for_generation(self):
+                return
+
+            def get_input_embeddings(self):
+                return self.model.embed_tokens
+
+        return CausalLM().eval().to(self.torch_device)
 
     @pytest.fixture
     def data(self):
         return torch.rand(10, 1000).to(self.torch_device)
@@ -1566,6 +1607,98 @@ def test_multiple_configs_with_bias_raises(self, tmp_path):
         config2 = LoraConfig(target_modules=["linear"], bias="none")
         model.add_adapter("other", config2)  # does not raise
 
+    def test_weight_tying_tied_model(self):
+        # If weight tying is enabled and `embed_tokens` is passed in
+        # `modules_to_save`, it needs to be ensured that lm_head is tied
+        # to the adapter added to `embed_tokens`
+        from peft.utils.other import ModulesToSaveWrapper
+
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
+            "Embed tokens and LM head types are not the same"
+        )
+
+        # Validate that all parameters are shared
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
+    def test_weight_tying_non_tied_model(self):
+        from peft.utils.other import ModulesToSaveWrapper
+
+        model = self.get_lm_model(tie_weights=False)
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+        with pytest.warns(UserWarning, match="no tied modules were found in the model"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.Linear"
+        )
+
+    def test_not_weight_tying_tied_model(self):
+        from peft.utils.other import ModulesToSaveWrapper
+
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tying=False,
+        )
+        with pytest.warns(UserWarning, match="`ensure_weight_tying` is not set to True"):
+            model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in Modules to Save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.Linear"
+        )
+
+    def test_weight_tying_tied_model_no_embed(self):
+        # If weight tying is enabled but `embed_tokens` is not passed in
+        # `modules_to_save`, neither module is wrapped and the original
+        # tie between embed_tokens and lm_head is left untouched
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            target_modules=["linear"],
+            ensure_weight_tying=True,
+        )
+
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, torch.nn.modules.Embedding)
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear)
+
+        # Validate that all parameters are shared
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
 
 class TestLokrInitialization:
     torch_device = infer_device()
Review comment: Let's mention that, right now, this only applies to `modules_to_save`, not the LoRA weights.