Commit 4a59a80

ENH Orthogonal LoRA layer initialization (2)
Continuation of, and supersedes, huggingface#2389. See the discussion there for further info.
1 parent 4c82bff commit 4a59a80

3 files changed: +71, -4 lines

src/peft/tuners/lora/config.py

Lines changed: 8 additions & 4 deletions
@@ -233,7 +233,7 @@ class LoraConfig(PeftConfig):
             Otherwise, it will use the original default value of `lora_alpha/r`.
         modules_to_save (`List[str]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
-        init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]`):
+        init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"]`):
             How to initialize the weights of the adapter layers. Passing True (default) results in the default
             initialization from the reference implementation from Microsoft, with the LoRA B weight being set to 0.
             This means that without further training, the LoRA adapter will be a no-op. Setting the initialization to
@@ -252,7 +252,9 @@ class LoraConfig(PeftConfig):
             a 7B model within seconds, and the training effect is approximately equivalent to using SVD. Passing
             `'corda'` results in the initialization of <a href='https://arxiv.org/abs/2406.05223'>Context-Oriented
             Decomposition Adaptation</a>, which converges even more rapidly than PiSSA in Instruction-Previewed Mode,
-            and preserves world knowledge better than LoRA in Knowledge-Preserved Mode.
+            and preserves world knowledge better than LoRA in Knowledge-Preserved Mode. Passing `"orthogonal"` results
+            in LoRA A and B being initialized orthogonally; in this, it resembles `"olora"`, but the base weights are
+            left untouched (requires `r` to be even, only supported for linear layers for now).
         layers_to_transform (`Union[List[int], int]`):
             The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
             that are specified in this list. If a single integer is passed, it will apply the transformations on the
@@ -356,7 +358,8 @@ class LoraConfig(PeftConfig):
         },
     )
     init_lora_weights: (
-        bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
+        bool
+        | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"]
    ) = field(
        default=True,
        metadata={
@@ -375,7 +378,8 @@ class LoraConfig(PeftConfig):
                "[number of iters] indicates the number of subspace iterations to perform fsvd, and must be a "
                "nonnegative integer. "
                "Passing `'corda'` results in CorDA initialization. "
-                "Pass `'loftq'` to use LoftQ initialization."
+                "Pass `'loftq'` to use LoftQ initialization. "
+                "Pass `'orthogonal'` for orthogonal initialization of LoRA A and B."
            ),
        },
    )
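
For orientation (not part of this commit), a minimal usage sketch of the new option, assuming a hypothetical toy model with a module named "linear", in the spirit of the tests added below:

# Hypothetical example, not from the commit: a toy model with a "linear" module.
import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

# r must be even for "orthogonal"; only linear layers are supported for now.
config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal", r=8)
model = get_peft_model(MLP(), config)

lora_A = model.linear.lora_A["default"].weight  # neither factor is zero ...
lora_B = model.linear.lora_B["default"].weight
print((lora_B @ lora_A).abs().max())  # ... but their product is ~0, so the adapter starts as a no-op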

src/peft/tuners/lora/layer.py

Lines changed: 20 additions & 0 deletions
@@ -230,6 +230,9 @@ def update_layer(
             self.loftq_init(adapter_name)
         elif init_lora_weights == "eva":
             nn.init.zeros_(self.lora_B[adapter_name].weight)
+        elif init_lora_weights == "orthogonal":
+            with gather_params_ctx(self.get_base_layer().weight):
+                self.orthogonal_init(adapter_name)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before init of the lora variants
@@ -440,6 +443,23 @@ def loftq_init(self, adapter_name):
         self.lora_embedding_B[adapter_name].weight.data = lora_B
         self.get_base_layer().weight.data = qweight

+    @torch.no_grad()
+    def orthogonal_init(self, adapter_name):
+        # https://datta0.github.io/posts/rethink-lora-init/#orthogonal-initialisation
+        rank = self.r[adapter_name]
+        if rank % 2 != 0:
+            raise ValueError(f"Orthogonal initialization requires the LoRA rank to be even, got {rank} instead.")
+
+        X = torch.randn(rank, rank)
+        Q, _ = torch.linalg.qr(X)
+        q_odd = Q[0::2, :]  # every second row, starting at row 0
+        q_even = Q[1::2, :]  # every second row, starting at row 1
+        dtype = self.get_base_layer().weight.dtype
+        lora_A = torch.randn(self.in_features, rank // 2).mm(q_odd).T / 10.0
+        lora_B = torch.randn(rank // 2, self.out_features).T.mm(q_even) / 10.0
+        self.lora_A[adapter_name].weight = nn.Parameter(lora_A.contiguous().to(dtype))
+        self.lora_B[adapter_name].weight = nn.Parameter(lora_B.contiguous().to(dtype))
+
     def _cache_store(self, key: str, value: Any) -> None:
         self._caches[key] = value

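
The construction relies on the rows of an orthogonal matrix being orthonormal: the even-indexed and odd-indexed row blocks of Q are mutually orthogonal, so lora_B @ lora_A collapses to (numerically) zero even though neither factor is zero on its own. A standalone sketch of the same recipe, with illustrative shapes, to make that explicit:

# Standalone illustration of the recipe in orthogonal_init (shapes are arbitrary).
import torch

in_features, out_features, rank = 32, 16, 8

X = torch.randn(rank, rank)
Q, _ = torch.linalg.qr(X)   # Q is orthogonal: its rows are orthonormal
q_odd = Q[0::2, :]          # (rank // 2, rank), every second row starting at 0
q_even = Q[1::2, :]         # (rank // 2, rank), every second row starting at 1

# Random mixtures of the two disjoint row blocks, as in orthogonal_init.
lora_A = torch.randn(in_features, rank // 2).mm(q_odd).T / 10.0    # (rank, in_features)
lora_B = torch.randn(rank // 2, out_features).T.mm(q_even) / 10.0  # (out_features, rank)

# q_even @ q_odd.T == 0 (distinct rows of Q are orthogonal), hence B @ A is zero
# up to floating-point error: the adapter is a no-op before training.
print((lora_B @ lora_A).abs().max())  # well below 1e-6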

tests/test_initialization.py

Lines changed: 43 additions & 0 deletions
@@ -277,6 +277,48 @@ def test_lora_conv2d_false(self):
         # as long as they are not zero, in order to avoid identity transformation.
         assert not torch.allclose(weight_B, torch.zeros_like(weight_B))

+    def test_lora_init_orthogonal(self):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal")
+        model = get_peft_model(model, config)
+
+        weight_A = model.linear.lora_A["default"].weight
+        weight_B = model.linear.lora_B["default"].weight
+
+        assert not torch.allclose(weight_A, torch.zeros_like(weight_A))
+        assert not torch.allclose(weight_B, torch.zeros_like(weight_B))
+        assert (weight_B @ weight_A).abs().max() < 1e-6
+
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_lora_init_orthogonal_half_precision_dtype(self, dtype):
+        try:
+            torch.zeros(1, dtype=dtype)
+        except Exception:
+            pytest.skip(f"dtype {dtype} not supported on this system, skipping test")
+
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal")
+        model = get_peft_model(model, config).to(dtype)
+
+        weight_A = model.linear.lora_A["default"].weight
+        weight_B = model.linear.lora_B["default"].weight
+
+        assert weight_A.dtype == dtype
+        assert weight_B.dtype == dtype
+
+    def test_lora_init_orthogonal_odd_rank_raises(self):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal", r=7)
+        msg = "Orthogonal initialization requires the LoRA rank to be even, got 7 instead."
+        with pytest.raises(ValueError, match=msg):
+            get_peft_model(model, config)
+
     def test_lora_scaling_default(self):
         # default is True
         torch.manual_seed(0)
@@ -1254,6 +1296,7 @@ def test_lora_with_bias_embedding_raises(self):
            {"init_lora_weights": "olora"},
            {"init_lora_weights": "pissa"},
            {"init_lora_weights": "pissa_niter_3"},
+            {"init_lora_weights": "orthogonal"},
        ],
    )
    def test_lora_with_bias_incompatible_arguments(self, extra_kwargs):
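
To exercise just the new tests locally, something like `pytest tests/test_initialization.py -k orthogonal` (run from the repository root of a development install) should select the cases added above.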
