
Commit ad86edf

BenjaminBossan authored and winglian committed
ENH Orthogonal LoRA layer initialization (2) (huggingface#2498)
Continuation of, and supersedes, huggingface#2389. Check the discussion there for further info.

Co-authored-by: Wing Lian <[email protected]>
1 parent 79f7213 commit ad86edf

3 files changed: +71 additions, -4 deletions

src/peft/tuners/lora/config.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -233,7 +233,7 @@ class LoraConfig(PeftConfig):
             use the original default value of `lora_alpha/r`.
         modules_to_save (`List[str]`):
             List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
-        init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]`):
+        init_lora_weights (`bool` | `Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"]`):
             How to initialize the weights of the adapter layers. Passing True (default) results in the default
             initialization from the reference implementation from Microsoft, with the LoRA B weight being set to 0.
             This means that without further training, the LoRA adapter will be a no-op. Setting the initialization to
@@ -253,7 +253,9 @@ class LoraConfig(PeftConfig):
             using SVD. Passing `'corda'` results in the initialization of <a
             href='https://huggingface.co/papers/2406.05223' >Context-Oriented Decomposition Adaptation</a>, which
             converges even more rapidly than PiSSA in Instruction-Previewed Mode, and preserves world knowledge better
-            than LoRA in Knowledge-Preserved Mode.
+            than LoRA in Knowledge-Preserved Mode. Passing `"orthogonal"` results in LoRA A and B being initialized
+            orthogonally; in this, it resembles `"olora"`, but the base weights are left untouched (requires `r` to be
+            even, only supported for linear layers for now).
         layers_to_transform (`Union[List[int], int]`):
             The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
             that are specified in this list. If a single integer is passed, it will apply the transformations on the
@@ -357,7 +359,8 @@ class LoraConfig(PeftConfig):
         },
     )
     init_lora_weights: (
-        bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
+        bool
+        | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq", "orthogonal"]
     ) = field(
         default=True,
         metadata={
@@ -376,7 +379,8 @@ class LoraConfig(PeftConfig):
                 "[number of iters] indicates the number of subspace iterations to perform fsvd, and must be a "
                 "nonnegative integer. "
                 "Passing `'corda'` results in CorDA initialization. "
-                "Pass `'loftq'` to use LoftQ initialization."
+                "Pass `'loftq'` to use LoftQ initialization. "
+                "Pass `'orthogonal'` for orthogonal initialization of LoRA A and B."
             ),
         },
     )
```
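For orientation, here is a minimal usage sketch of the new option (not part of the commit; the target module names and rank below are illustrative):

```python
# Minimal usage sketch (not from this commit); module names are hypothetical.
from peft import LoraConfig

config = LoraConfig(
    r=16,                                 # must be even for "orthogonal" init
    target_modules=["q_proj", "v_proj"],  # illustrative linear layers to adapt
    init_lora_weights="orthogonal",       # option added by this commit
)
```

As documented above, the option is currently only supported for linear layers and, unlike `"olora"`, leaves the base weights untouched.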

src/peft/tuners/lora/layer.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -233,6 +233,9 @@ def update_layer(
                 self.loftq_init(adapter_name)
         elif init_lora_weights == "eva":
             nn.init.zeros_(self.lora_B[adapter_name].weight)
+        elif init_lora_weights == "orthogonal":
+            with gather_params_ctx(self.get_base_layer().weight):
+                self.orthogonal_init(adapter_name)
         elif init_lora_weights:
             self.reset_lora_parameters(adapter_name, init_lora_weights)
         # call this before init of the lora variants
@@ -443,6 +446,23 @@ def loftq_init(self, adapter_name):
         self.lora_embedding_B[adapter_name].weight.data = lora_B
         self.get_base_layer().weight.data = qweight
 
+    @torch.no_grad()
+    def orthogonal_init(self, adapter_name):
+        # https://datta0.github.io/posts/rethink-lora-init/#orthogonal-initialisation
+        rank = self.r[adapter_name]
+        if rank % 2 != 0:
+            raise ValueError(f"Orthogonal initialization requires the LoRA rank to be even, got {rank} instead.")
+
+        X = torch.randn(rank, rank)
+        Q, _ = torch.linalg.qr(X)
+        q_odd = Q[0::2, :]  # Odd rows
+        q_even = Q[1::2, :]  # Even rows
+        dtype = self.get_base_layer().weight.dtype
+        lora_A = torch.randn(self.in_features, rank // 2).mm(q_odd).T / 10.0
+        lora_B = torch.randn(rank // 2, self.out_features).T.mm(q_even) / 10.0
+        self.lora_A[adapter_name].weight = nn.Parameter(lora_A.contiguous().to(dtype))
+        self.lora_B[adapter_name].weight = nn.Parameter(lora_B.contiguous().to(dtype))
+
     def _cache_store(self, key: str, value: Any) -> None:
         self._caches[key] = value
 
```
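To see why this construction makes the adapter a no-op at initialization, here is a standalone sketch of the same QR-based recipe (not PEFT code; the sizes are made up). Splitting the rows of a square orthogonal matrix into two disjoint halves, and building A from one half and B from the other, forces B @ A to collapse to (numerically) zero:

```python
# Standalone sketch of the QR-based construction in orthogonal_init; sizes are illustrative.
import torch

rank, in_features, out_features = 8, 32, 64

X = torch.randn(rank, rank)
Q, _ = torch.linalg.qr(X)   # Q is square and orthogonal, so its rows are orthonormal
q_odd = Q[0::2, :]          # one half of the rows, shape (rank // 2, rank)
q_even = Q[1::2, :]         # the other, disjoint half, shape (rank // 2, rank)

lora_A = torch.randn(in_features, rank // 2).mm(q_odd).T / 10.0    # (rank, in_features)
lora_B = torch.randn(rank // 2, out_features).T.mm(q_even) / 10.0  # (out_features, rank)

# B @ A reduces to (random factor) @ (q_even @ q_odd.T) @ (random factor), and
# q_even @ q_odd.T == 0 because the two row sets are orthonormal and disjoint.
print((lora_B @ lora_A).abs().max())  # near machine epsilon: the LoRA update starts at zero
```

The division by 10 only scales the magnitude of the individual factors; it does not affect the orthogonality argument.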

tests/test_initialization.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -277,6 +277,48 @@ def test_lora_conv2d_false(self):
         # as long as they are not zero, in order to avoid identity transformation.
         assert not torch.allclose(weight_B, torch.zeros_like(weight_B))
 
+    def test_lora_init_orthogonal(self):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal")
+        model = get_peft_model(model, config)
+
+        weight_A = model.linear.lora_A["default"].weight
+        weight_B = model.linear.lora_B["default"].weight
+
+        assert not torch.allclose(weight_A, torch.zeros_like(weight_A))
+        assert not torch.allclose(weight_B, torch.zeros_like(weight_B))
+        assert (weight_B @ weight_A).abs().max() < 1e-6
+
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_lora_init_orthogonal_half_precision_dtype(self, dtype):
+        try:
+            torch.zeros(1, dtype=dtype)
+        except Exception:
+            pytest.skip(f"dtype {dtype} not supported on this system, skipping test")
+
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal")
+        model = get_peft_model(model, config).to(dtype)
+
+        weight_A = model.linear.lora_A["default"].weight
+        weight_B = model.linear.lora_B["default"].weight
+
+        assert weight_A.dtype == dtype
+        assert weight_B.dtype == dtype
+
+    def test_lora_init_orthogonal_odd_rank_raises(self):
+        torch.manual_seed(0)
+
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal", r=7)
+        msg = "Orthogonal initialization requires the LoRA rank to be even, got 7 instead."
+        with pytest.raises(ValueError, match=msg):
+            get_peft_model(model, config)
+
     def test_lora_scaling_default(self):
         # default is True
         torch.manual_seed(0)
@@ -1255,6 +1297,7 @@ def test_lora_with_bias_embedding_raises(self):
             {"init_lora_weights": "olora"},
             {"init_lora_weights": "pissa"},
             {"init_lora_weights": "pissa_niter_3"},
+            {"init_lora_weights": "orthogonal"},
         ],
     )
     def test_lora_with_bias_incompatible_arguments(self, extra_kwargs):
```
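Finally, a hedged end-to-end sketch of the property the first test checks, using a made-up toy model (the `TinyModel` class and shapes below are not from the test suite): because B @ A starts at (numerically) zero, the freshly wrapped model should match the base model before any training.

```python
# End-to-end sketch with a made-up toy model; not part of the PEFT test suite.
import torch
from torch import nn
from peft import LoraConfig, get_peft_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


torch.manual_seed(0)
model = TinyModel()
x = torch.randn(4, 16)
base_out = model(x)  # reference output before attaching the adapter

config = LoraConfig(target_modules=["linear"], init_lora_weights="orthogonal")
peft_model = get_peft_model(model, config)

# The adapter contribution is B @ A @ x (up to scaling), which is ~0 at init.
assert torch.allclose(base_out, peft_model(x), atol=1e-5)
```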
