Ensure weight tying is maintained for embed_tokens and lm_head #2803
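This PR adds an `ensure_weight_tieing` option so that, when `embed_tokens` is placed in `modules_to_save` for a causal LM whose `lm_head` is tied to the input embeddings, the head keeps sharing the trainable embedding copy instead of silently diverging from it. A minimal usage sketch, based on the tests added in this PR — the model name and `target_modules` are illustrative and assume a checkpoint with `tie_word_embeddings=True`:

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Illustrative model; any causal LM with tied embeddings would do.
    model = AutoModelForCausalLM.from_pretrained("some-org/tied-causal-lm")

    config = LoraConfig(
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],     # illustrative target modules
        modules_to_save=["embed_tokens"],        # trainable copy of the embedding
        ensure_weight_tieing=True,               # flag introduced by this PR
    )
    peft_model = get_peft_model(model, config)

    # With the flag set, lm_head and the trainable embed_tokens copy should
    # keep pointing at the same parameter, as verified in the tests below.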
@@ -52,6 +52,7 @@
     SAFETENSORS_WEIGHTS_NAME,
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
     WEIGHTS_NAME,
+    ModulesToSaveWrapper,
     PeftType,
     TaskType,
     _get_batch_size,

@@ -1845,6 +1846,31 @@ def __init__(
         super().__init__(model, peft_config, adapter_name, **kwargs)
         self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

+        # Condition to check if embedding layer (`embed_tokens`) is added
+        # in `modules_to_save` and we want to ensure the `lm_head`
+        # does not diverge from the `embed_tokens` layer
+        if (
+            peft_config.task_type == "CAUSAL_LM"
+            and hasattr(model.get_input_embeddings(), "modules_to_save")
+            and getattr(peft_config, "ensure_weight_tieing")
+        ):
+            module_keys = BaseTuner._get_tied_modules_to_save(self, model)
+
+            if not module_keys:
+                warnings.warn("You have requested ensure_weight_tieing, but no tied modules were found")
+
+            tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name)
+
+            _set_trainable(
+                model,
+                adapter_name,
+                inference_mode=peft_config.inference_mode,
+                module_names=module_keys,
+                strict_module_check=True,
+                wrapper_cls=ModulesToSaveWrapper,
+                tied_module=tied_module,
+            )
+
     def forward(
         self,
         input_ids=None,
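For context, weight tying in a causal LM means `lm_head.weight` and `embed_tokens.weight` are literally the same `nn.Parameter`. A deep copy of the embedding — which is what `modules_to_save` creates by default — gets its own parameter, so the head stops following the trained embedding. A minimal PyTorch sketch of the failure mode this branch guards against (toy sizes, no PEFT involved):

    import copy
    import torch
    from torch import nn

    embed_tokens = nn.Embedding(10, 4)
    lm_head = nn.Linear(4, 10, bias=False)
    lm_head.weight = embed_tokens.weight          # tied: one shared Parameter
    assert lm_head.weight is embed_tokens.weight

    trainable_copy = copy.deepcopy(embed_tokens)  # what modules_to_save normally does
    assert trainable_copy.weight is not lm_head.weight  # the tie is broken

    with torch.no_grad():
        trainable_copy.weight += 1.0              # training the copy...
    assert not torch.equal(trainable_copy.weight, lm_head.weight)  # ...leaves lm_head behind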
@@ -508,10 +508,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
     # All names of layers that may contain adapter (trainable) weights
     adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

-    def __init__(self, module_to_save, adapter_name):
-        super().__init__(module_to_save, adapter_name)
+    def __init__(self, module_to_save, adapter_name, tied_module=None):
+        super().__init__(module_to_save, adapter_name, tied_module=tied_module)

-    def init_modules(self, adapter_name):
+    def init_modules(self, adapter_name, **kwargs):
         # we treat each adapter separately, so we have multiple adapters, same (copied) module for each
         self.modules_to_save = torch.nn.ModuleDict({})

@@ -548,9 +548,17 @@ def update(self, adapter_name, **kwargs):
                 context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
                 break

+        tied_module = kwargs.get("tied_module", None)
+
         if adapter_name not in self.modules_to_save:
             with context_manager:
-                self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)
+                if tied_module:
+                    new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False)
+                    new_linear.weight = tied_module.weight
+
+                    self.modules_to_save[adapter_name] = new_linear
+                else:
+                    self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module)

         if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
             old_hook = self.modules_to_save[adapter_name]._hf_hook
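The `tied_module` branch replaces the deep copy with a fresh `nn.Linear` that reuses the embedding adapter's `Parameter` object, so the head and the trainable embedding copy stay in lockstep. A standalone sketch of that sharing — the names and shapes are illustrative; note that for a vocab x hidden embedding, the head's weight also has shape `(vocab, hidden)`, so the same tensor can back both modules:

    import torch
    from torch import nn

    vocab, hidden = 10, 4
    embed_copy = nn.Embedding(vocab, hidden)       # stands in for the modules_to_save copy of embed_tokens

    # Build a head that shares the exact same Parameter instead of copying it.
    tied_head = nn.Linear(hidden, vocab, bias=False)
    tied_head.weight = embed_copy.weight

    assert tied_head.weight is embed_copy.weight   # same object, not just equal values

    with torch.no_grad():
        embed_copy.weight += 1.0                   # an update to the embedding copy...
    assert torch.equal(tied_head.weight, embed_copy.weight)  # ...is immediately visible in the head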
@@ -110,6 +110,47 @@ def forward(self, x):

         return MyModule().eval().to(self.torch_device)

+    def get_lm_model(self, bias=True, tie_weights=True):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.linear = nn.Linear(1000, 1000, bias=bias)
+                self.embed_tokens = nn.Embedding(1000, 1000)
+                self.conv2d = nn.Conv2d(100, 100, 3, bias=bias)
+
+            def forward(self, x):
+                x_int = (x * 100).int()
+                x_4d = x.reshape(1, 100, 10, 10)
+
+                return self.linear(x), self.embed_tokens(x_int), self.conv2d(x_4d)
+
+        class CausalLM(nn.Module):
+            if tie_weights:
+                _tied_weights_keys = ["lm_head.weight"]
+
+            def __init__(self):
+                super().__init__()
+                self.model = MyModule()
+                self.config = {"tie_word_embeddings": tie_weights}
+
+                if tie_weights:
+                    self.lm_head = nn.Linear(1000, 1000, bias=False)
+                    self.lm_head.weight = self.model.embed_tokens.weight
+                else:
+                    self.lm_head = nn.Linear(1000, 1000, bias=bias)
+
+            def forward(self, x):
+                return self.model(x)
+
+            def prepare_inputs_for_generation(self):
+                return
+
+            def get_input_embeddings(self):
+                return self.model.embed_tokens
+
+        return CausalLM().eval().to(self.torch_device)
+
     @pytest.fixture
     def data(self):
         return torch.rand(10, 1000).to(self.torch_device)
@@ -1566,6 +1607,75 @@ def test_multiple_configs_with_bias_raises(self, tmp_path):
         config2 = LoraConfig(target_modules=["linear"], bias="none")
         model.add_adapter("other", config2)  # does not raise

+    def test_weight_tieing_tied_model(self):
+        # If weight tying is enabled and `embed_tokens` is passed in
+        # `modules_to_save`, it needs to be ensured that lm_head stays
+        # tied to the adapter copy added for `embed_tokens`
+        from peft.utils.other import ModulesToSaveWrapper
+
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            task_type="CAUSAL_LM",
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tieing=True,
+        )
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in modules_to_save"
+        )
+        assert type(model.base_model.model.model.embed_tokens) is type(model.base_model.model.lm_head), (
+            "Embed tokens and LM head types are not the same"
+        )
+
+        # Validate that all parameters are shared
+        embed_np = dict(model.base_model.model.model.embed_tokens.named_parameters())
+        lm_head_np = dict(model.base_model.model.lm_head.named_parameters())
+
+        for k in embed_np.keys():
+            assert torch.allclose(embed_np[k], lm_head_np[k])
+            assert embed_np[k] is lm_head_np[k]
+
+    def test_weight_tieing_non_tied_model(self):
+        from peft.utils.other import ModulesToSaveWrapper
+
+        model = self.get_lm_model(tie_weights=False)
+        embed_token_config = LoraConfig(
+            task_type="CAUSAL_LM",
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tieing=True,
+        )
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in modules_to_save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.Linear"
+        )
+
+    def test_not_weight_tieing_tied_model(self):
+        from peft.utils.other import ModulesToSaveWrapper
+
+        model = self.get_lm_model()
+        embed_token_config = LoraConfig(
+            task_type="CAUSAL_LM",
+            modules_to_save=["embed_tokens"],
+            target_modules=["linear"],
+            ensure_weight_tieing=False,
+        )
+        model = get_peft_model(model, embed_token_config)
+
+        assert isinstance(model.base_model.model.model.embed_tokens, ModulesToSaveWrapper), (
+            "Embed tokens is not added in modules_to_save"
+        )
+        assert isinstance(model.base_model.model.lm_head, torch.nn.modules.linear.Linear), (
+            "LM head is not of type nn.Linear"
+        )
+
+
 class TestLokrInitialization:
     torch_device = infer_device()