
Commit b1fd97d

FIX Several bugs in LoKr (#2180)
- Added `rank_dropout_scale` parameter
- Fixed scale handling: the adapter scaling is now applied inside `make_kron`
- Added LyCORIS-style weight initialization (`init_weights="lycoris"`)
1 parent 13fb29f commit b1fd97d

5 files changed (+125, -9 lines)

src/peft/tuners/lokr/config.py

Lines changed: 9 additions & 6 deletions
@@ -14,7 +14,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import Optional, Union
+from typing import Literal, Optional, Union

 from peft.tuners.lycoris_utils import LycorisConfig
 from peft.utils import PeftType
@@ -40,6 +40,8 @@ class LoKrConfig(LycorisConfig):
             Perform rank decomposition of left kronecker product matrix.
         decompose_factor (`int`):
             Kronecker product decomposition factor.
+        rank_dropout_scale (`bool`):
+            Whether to scale the rank dropout while training. Defaults to `False`.
         target_modules (`Optional[Union[List[str], str]]`):
             The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
             names will be replaced. When passing a string, a regex match will be performed. When passing a list of
@@ -53,8 +55,8 @@ class LoKrConfig(LycorisConfig):
             When passing a list of strings, either an exact match will be performed or it is checked if the name of the
             module ends with any of the passed strings.
         init_weights (`bool`):
-            Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is
-            discouraged.
+            Whether to perform initialization of adapter weights. This defaults to `True`. Use "lycoris" to initialize
+            weights in the style of the LYCORIS repository. Passing `False` is discouraged.
         layers_to_transform (`Union[List[int], int]`):
             The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
             that are specified in this list. If a single integer is passed, it will apply the transformations on the
@@ -91,6 +93,7 @@ class LoKrConfig(LycorisConfig):
         metadata={"help": "Perform rank decomposition of left kronecker product matrix."},
     )
     decompose_factor: int = field(default=-1, metadata={"help": "Kronecker product decomposition factor."})
+    rank_dropout_scale: bool = field(default=False, metadata={"help": "Rank dropout scale"})
     target_modules: Optional[Union[list[str], str]] = field(
         default=None,
         metadata={
@@ -103,12 +106,12 @@ class LoKrConfig(LycorisConfig):
         default=None,
         metadata={"help": "List of module names or regex expression of the module names to exclude from LoKr."},
     )
-    init_weights: bool = field(
+    init_weights: Union[bool, Literal["lycoris"]] = field(
         default=True,
         metadata={
             "help": (
-                "Whether to initialize the weights of the LoKr layers with their default initialization. Don't change "
-                "this setting, except if you know exactly what you're doing."
+                "Whether to initialize the weights of the LoKr layers with their default initialization. Can be True, False or 'lycoris'. "
+                "Default is True. Don't change this setting to False, except if you know exactly what you're doing."
             ),
         },
     )
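Taken together, the new config fields can be exercised as follows. This is a minimal, illustrative sketch: the toy model, module name, and hyperparameter values are placeholders, mirroring the tests added further below; `LoKrConfig` and `get_peft_model` are the existing PEFT entry points.

import torch.nn as nn

from peft import LoKrConfig, get_peft_model


# Hypothetical toy model; any nn.Module with a named target module works.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)


config = LoKrConfig(
    target_modules=["linear"],
    init_weights="lycoris",   # new: LyCORIS-style init, the adapter starts as a zero delta
    rank_dropout=0.1,         # existing option: randomly drop rows of the delta weight in training
    rank_dropout_scale=True,  # new: rescale surviving rows so the mask mean stays 1
)
peft_model = get_peft_model(MyModel(), config)

As the tests below check, with the default `init_weights=True` or with `init_weights="lycoris"` the wrapped model's outputs match the base model's at initialization, while `init_weights=False` changes them.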

src/peft/tuners/lokr/layer.py

Lines changed: 25 additions & 3 deletions
@@ -126,6 +126,23 @@ def reset_adapter_parameters_random(self, adapter_name: str):
         if adapter_name in self.lokr_t2:
             nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

+    # Initializes the weight matrices in the same way as they are initialized in the LyCORIS repository.
+    def reset_adapter_parameters_lycoris_way(self, adapter_name):
+        if adapter_name in self.lokr_w1:
+            nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
+        else:
+            nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
+            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))
+
+        if adapter_name in self.lokr_w2:
+            nn.init.zeros_(self.lokr_w2[adapter_name])
+        else:
+            nn.init.zeros_(self.lokr_w2_b[adapter_name])
+            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
+
+        if adapter_name in self.lokr_t2:
+            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))
+
     def update_layer(
         self,
         adapter_name: str,
@@ -160,6 +177,7 @@ def update_layer(
         self.scaling[adapter_name] = alpha / r
         self.rank_dropout[adapter_name] = rank_dropout
         self.module_dropout[adapter_name] = module_dropout
+        self.rank_dropout_scale[adapter_name] = kwargs["rank_dropout_scale"]
         base_layer = self.get_base_layer()

         # Determine shape of LoKr weights
@@ -192,7 +210,10 @@

         # Initialize weights
         if init_weights:
-            self.reset_adapter_parameters(adapter_name)
+            if init_weights == "lycoris":
+                self.reset_adapter_parameters_lycoris_way(adapter_name)
+            else:
+                self.reset_adapter_parameters(adapter_name)
         else:
             self.reset_adapter_parameters_random(adapter_name)

@@ -215,15 +236,16 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
             w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

         # Make weights with Kronecker product
-        weight = make_kron(w1, w2)
+        weight = make_kron(w1, w2, self.scaling[adapter_name])
         weight = weight.reshape(self.get_base_layer().weight.shape)

         # Perform rank dropout during training - drop rows of addition weights
         rank_dropout = self.rank_dropout[adapter_name]
         if self.training and rank_dropout:
             drop = (torch.rand(weight.size(0)) > rank_dropout).float()
             drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
-            drop /= drop.mean()
+            if self.rank_dropout_scale[adapter_name]:
+                drop /= drop.mean()
             weight *= drop

         return weight
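Two behaviors introduced here are easy to verify in isolation. With the LyCORIS-style init, `lokr_w2` (or `lokr_w2_b`) is zero-initialized, so the Kronecker-product delta weight is exactly zero and the adapted layer reproduces the base layer until training updates it. With `rank_dropout_scale=True`, the surviving rows of the dropout mask are divided by the mask's mean, i.e. scaled by the inverse keep-fraction. The following stand-alone sketch uses plain torch, not the PEFT internals; the shapes and dropout rate are arbitrary.

import torch

torch.manual_seed(0)

# 1) LyCORIS-style init: a zero Kronecker factor makes the whole delta weight zero.
w1 = torch.randn(4, 4)
w2 = torch.zeros(8, 8)                # zero-initialized factor, as in reset_adapter_parameters_lycoris_way
delta = torch.kron(w1, w2)            # shape (32, 32), all zeros -> base layer output is unchanged
assert torch.all(delta == 0)

# 2) Rank dropout with scaling: dividing by the mask mean rescales kept rows by 1 / keep-fraction.
rank_dropout = 0.25
weight = torch.randn(32, 32)
drop = (torch.rand(weight.size(0)) > rank_dropout).float()   # roughly 75% of rows kept
scaled_drop = drop / drop.mean()                             # what rank_dropout_scale=True enables
print(drop.mean().item(), scaled_drop.mean().item())         # e.g. ~0.75 vs ~1.0
weight = weight * scaled_drop.view(-1, 1)                    # kept rows are up-weighted, dropped rows stay zero

Note that the scaling (`drop /= drop.mean()`) was previously applied unconditionally; with this change it only runs when `rank_dropout_scale=True`, which defaults to `False`.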

src/peft/tuners/lokr/model.py

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ def _create_and_replace(
         kwargs = config.to_dict()
         kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
         kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
+        kwargs["rank_dropout_scale"] = config.rank_dropout_scale

         if isinstance(target, LoKrLayer):
             target.update_layer(adapter_name, **kwargs)

src/peft/tuners/lycoris_utils.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ def __init__(self, base_layer: nn.Module) -> None:
         self.alpha = {}
         self.scaling = {}
         self.rank_dropout = {}
+        self.rank_dropout_scale = {}
         self.module_dropout = {}

         # Tuner info

tests/test_initialization.py

Lines changed: 89 additions & 0 deletions
@@ -32,6 +32,7 @@
 from peft import (
     AdaLoraConfig,
     IA3Config,
+    LoKrConfig,
     LoraConfig,
     PeftMixedModel,
     PeftModel,
@@ -1120,6 +1121,94 @@ def test_lora_use_dora_with_megatron_core_raises(self):
             LoraConfig(target_modules=["linear"], use_dora=True, megatron_config=megatron_config)


+class TestLokrInitialization:
+    torch_device = infer_device()
+
+    def get_model(self):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Choose a large weight so that averages are close to expected values.
+                self.linear = nn.Linear(1000, 1000)
+                self.conv2d = nn.Conv2d(100, 100, 3)
+
+            def forward(self, x):
+                x_4d = x.flatten().reshape(1, 100, 10, 10)
+                return self.linear(x), self.conv2d(x_4d)
+
+        return MyModule().eval().to(self.torch_device)
+
+    @pytest.fixture
+    def data(self):
+        return torch.rand(10, 1000).to(self.torch_device)
+
+    def test_lokr_linear_init_default(self, data):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        output_before = model(data)[0]
+        config = LoKrConfig(target_modules=["linear"])
+        model = get_peft_model(model, config)
+        output_after = model(data)[0]
+
+        assert torch.allclose(output_before, output_after)
+
+    def test_lokr_linear_init_false(self, data):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        output_before = model(data)[0]
+        config = LoKrConfig(target_modules=["linear"], init_weights=False)
+        model = get_peft_model(model, config)
+        output_after = model(data)[0]
+
+        assert not torch.allclose(output_before, output_after)
+
+    def test_lokr_linear_init_lycoris(self, data):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        output_before = model(data)[0]
+        config = LoKrConfig(target_modules=["linear"], init_weights="lycoris")
+        model = get_peft_model(model, config)
+        output_after = model(data)[0]
+
+        assert torch.allclose(output_before, output_after)
+
+    def test_lokr_conv2d_init_default(self, data):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        output_before = model(data)[1]
+        config = LoKrConfig(target_modules=["conv2d"])
+        model = get_peft_model(model, config)
+        output_after = model(data)[1]
+
+        assert torch.allclose(output_before, output_after)
+
+    def test_lokr_conv2d_init_false(self, data):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        output_before = model(data)[1]
+        config = LoKrConfig(target_modules=["conv2d"], init_weights=False)
+        model = get_peft_model(model, config)
+        output_after = model(data)[1]
+
+        assert not torch.allclose(output_before, output_after)
+
+    def test_lokr_conv2d_init_lycoris(self, data):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        output_before = model(data)[1]
+        config = LoKrConfig(target_modules=["conv2d"], init_weights="lycoris")
+        model = get_peft_model(model, config)
+        output_after = model(data)[1]
+
+        assert torch.allclose(output_before, output_after)
+
+
 class TestAdaLoraInitialization:
     torch_device = infer_device()

