     ),
     ("Conv2d 1 LoRA with lora_b bias", "Conv2d", LoraConfig, {"target_modules": ["conv2d"], "lora_bias": True}),
     ("Conv3d 1 LoRA with lora_b bias", "Conv3d", LoraConfig, {"target_modules": ["conv3d"], "lora_bias": True}),
+    ("MHA 1 LoRA", "MHA", LoraConfig, {"target_modules": ["mha"]}),
+    ("MHA 2 LoRA", "MHA", LoraConfig, {"target_modules": ["mha", "lin0"]}),
     #######
     # IA³ #
     #######
@@ -872,6 +874,21 @@ def forward(self, X):
         return X


+class ModelMha(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mha = nn.MultiheadAttention(10, 2)
+        self.lin0 = nn.Linear(10, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
+
+    def forward(self, X):
+        X = X.float()
+        X, _ = self.mha(X, X, X)
+        X = self.lin0(X)
+        X = self.sm(X)
+        return X
+
+
 class MockTransformerWrapper:
     """Mock class to behave like a transformers model.

@@ -908,6 +925,9 @@ def from_pretrained(cls, model_id, torch_dtype=None):
         if model_id == "Conv2d2":
            return ModelConv2D2().to(torch_dtype)

+        if model_id == "MHA":
+            return ModelMha().to(torch_dtype)
+
        raise ValueError(f"model_id {model_id} not implemented")


@@ -1074,12 +1094,13 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_k
         model_before = copy.deepcopy(model)

         model.train()
-        # this high learning rate was found through testing to be necessary to avoid flakiness
-        lr = (
-            100.0
-            if (config_kwargs.get("use_dora") and model_id == "EmbConv1D") or issubclass(config_cls, VBLoRAConfig)
-            else 0.5
-        )
+        lr = 0.5
+        if (config_kwargs.get("use_dora") and model_id == "EmbConv1D") or issubclass(config_cls, VBLoRAConfig):
+            # this high learning rate was found through testing to be necessary to avoid flakiness
+            lr = 100
+        elif "mha" in model_id.lower():
+            # we get exploding gradients with MHA when learning rate is too high
+            lr = 1e-3
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@@ -1117,8 +1138,13 @@ def test_parameters_after_loading_model(self, test_name, model_id, config_cls, c
         )
         model = get_peft_model(model, config)
         model.train()
-        lr = 0.5 if not config_kwargs.get("use_dora") else 0.1  # otherwise we get nan
-        if issubclass(config_cls, VBLoRAConfig):
+
+        lr = 0.5
+        if config_kwargs.get("use_dora"):
+            lr = 0.1  # otherwise we get nan
+        elif "mha" in model_id.lower():
+            lr = 1e-3  # we get exploding gradients with MHA when learning rate is too high
+        elif issubclass(config_cls, VBLoRAConfig):
             lr = 0.01  # otherwise we get nan
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

@@ -1775,6 +1801,14 @@ def test_gpt2_dora_merge_and_unload_safe_merge(self):
         # should not raise an error
         model.merge_and_unload(safe_merge=True)

+    def test_unload_adapter_multihead_attention(self):
+        # MultiheadAttention has special logic for unloading, that logic is covered by this test
+        self._test_unload_adapter(
+            model_id="MHA",
+            config_cls=LoraConfig,
+            config_kwargs={"target_modules": ["mha"], "init_lora_weights": False},
+        )
+
     def test_dora_save_and_load_remapping(self):
         # Here we test the refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a ModuleDict
         # with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since we want the
@@ -1810,6 +1844,37 @@ def test_dora_save_and_load_remapping(self):
         for k in state_dict:
             assert torch.allclose(state_dict[k], state_dict_loaded[k])

+    @parameterized.expand([False, True])
+    def test_mha_gradients_set_correctly(self, with_forward_call):
+        # check for this bug: https://github.com/huggingface/peft/issues/761#issuecomment-1893804738
+        base_model = ModelMha()
+        config = LoraConfig(target_modules=["mha"])
+        model = get_peft_model(base_model, config)
+        model = model.to(self.torch_device)
+
+        if with_forward_call:
+            # after the merge-unmerge roundtrip happening in forward of lora MHA, the base weights should be set to
+            # requires_grad=False
+            inputs = self.prepare_inputs_for_testing()
+            model(**inputs)
+
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is False
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is False
+
+        # _restore_weights used to ignore the gradient, this checks that it is indeed considered
+        model.base_model.model.mha._restore_weights()
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is False
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is False
+
+        model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad = True
+        model.base_model.model.mha.base_layer.in_proj_weight.requires_grad = True
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is True
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True
+
+        model.base_model.model.mha._restore_weights()
+        assert model.base_model.model.mha.base_layer.out_proj.base_layer.weight.requires_grad is True
+        assert model.base_model.model.mha.base_layer.in_proj_weight.requires_grad is True
+

 class TestMultiRankAdapter(unittest.TestCase):
     """Tests related to multirank LoRA adapters"""
@@ -3630,6 +3695,18 @@ def test_mixed_adapter_batches_lora_conv2d(self):
         inputs = {"X": torch.arange(270).view(6, 5, 3, 3).to(self.torch_device)}
         self.run_checks(peft_model, inputs)

+    def test_mixed_adapter_batches_mha_raises(self):
+        base_model = ModelMha().to(self.torch_device).eval()
+        config0 = LoraConfig(target_modules=["mha"], init_lora_weights=False)
+        config1 = LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config0, "adapter0").eval()
+        peft_model.add_adapter("adapter1", config1)
+
+        inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
+        msg = "lora.MultiheadAttention does not support mixed adapter batches"
+        with pytest.raises(TypeError, match=msg):
+            self.run_checks(peft_model, inputs)
+
     def test_mixed_adapter_batches_lora_length_mismatch_raises(self, mlp_lora):
         inputs = {
             "X": torch.arange(90).view(-1, 10).to(self.torch_device),
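
For orientation, here is a minimal sketch of the scenario these new tests exercise: applying LoRA to a model built around nn.MultiheadAttention. It is not part of the diff and assumes a peft build that includes the lora.MultiheadAttention support covered by the tests above; the toy model mirrors the ModelMha class added in this diff.

# Sketch only (not part of this commit): LoRA on nn.MultiheadAttention,
# assuming a peft version with the MHA support exercised by the tests above.
import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class ModelMha(nn.Module):
    # same toy model as the one added in this diff
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(10, 2)
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = X.float()
        X, _ = self.mha(X, X, X)
        X = self.lin0(X)
        return self.sm(X)


base_model = ModelMha()
config = LoraConfig(target_modules=["mha"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config)

X = torch.arange(90).view(-1, 10).float()  # same dummy input shape the tests use
out = peft_model(X)  # forward goes through the LoRA-wrapped MultiheadAttention
print(out.shape)  # torch.Size([9, 2])

unloaded = peft_model.merge_and_unload()  # plain ModelMha with the LoRA weights merged in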