Ensure weight tying is maintained for embed_tokens and lm_head #2803
base: main
@@ -52,6 +52,7 @@
     SAFETENSORS_WEIGHTS_NAME,
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
     WEIGHTS_NAME,
+    ModulesToSaveWrapper,
     PeftType,
     TaskType,
     _get_batch_size,
@@ -1848,6 +1849,27 @@ def __init__(
         super().__init__(model, peft_config, adapter_name, **kwargs)
         self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

+        # Condition to check if embedding layer (`embed_tokens`) is added
+        # in `modules_to_save` and we want to ensure the `lm_head`
+        # does not diverge from the `embed_tokens` layer
+        if (
+            peft_config.task_type == "CAUSAL_LM"
+            and hasattr(model.get_input_embeddings(), "modules_to_save")
+            and getattr(peft_config, "ensure_weight_tieing")
+        ):
+            module_keys = BaseTuner._get_tied_modules_to_save(self, model)
+            tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+
+            _set_trainable(
+                model,
+                adapter_name,
+                inference_mode=peft_config.inference_mode,
+                module_names=module_keys,
+                strict_module_check=True,
+                wrapper_cls=ModulesToSaveWrapper,
+                tied_module=tied_module,
+            )
+
     def forward(
         self,
         input_ids=None,
@@ -506,10 +506,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
     # All names of layers that may contain adapter (trainable) weights
     adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

-    def __init__(self, module_to_save, adapter_name):
-        super().__init__(module_to_save, adapter_name)
+    def __init__(self, module_to_save, adapter_name, tied_module=None):
+        super().__init__(module_to_save, adapter_name, tied_module=tied_module)

-    def init_modules(self, adapter_name):
+    def init_modules(self, adapter_name, **kwargs):
         # we treat each adapter separately, so we have multiple adapters, same (copied) module for each
         self.modules_to_save = torch.nn.ModuleDict({})
@@ -546,9 +546,17 @@ def update(self, adapter_name, **kwargs):
                 context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
                 break

+        tied_module = kwargs.get("tied_module", None)
+
         if adapter_name not in self.modules_to_save:
             with context_manager:
-                self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
+                if tied_module:
+                    new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
+                    new_linear.weight = tied_module.weight
+
+                    self.modules_to_save[adapter_name] = new_linear
+                else:
+                    self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

         if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
             old_hook = self.modules_to_save[adapter_name]._hf_hook
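As a standalone illustration of why the `tied_module` branch keeps the two layers in sync: assigning the tied module's weight Parameter to a freshly created Linear shares storage rather than copying it. This is plain PyTorch with made-up sizes, not PEFT internals:

```python
import torch

# Stand-in for the embed_tokens copy stored in modules_to_save (sizes are arbitrary).
vocab_size, hidden_size = 128, 16
embed_copy = torch.nn.Embedding(vocab_size, hidden_size)

# Mirrors `torch.nn.Linear(*tied_module.weight.shape, bias=False)` from the diff;
# re-assigning `.weight` afterwards makes the Linear use the embedding's Parameter,
# which also gives it the (vocab_size, hidden_size) weight expected of an lm_head.
new_linear = torch.nn.Linear(*embed_copy.weight.shape, bias=False)
new_linear.weight = embed_copy.weight  # share the Parameter object, no copy

assert new_linear.weight is embed_copy.weight
with torch.no_grad():
    embed_copy.weight[0, 0] = 42.0
print(new_linear.weight[0, 0].item())  # 42.0, an update to one is visible through the other
```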