Commit cde0624

adapted linear layers with base_layer_result
1 parent 45fb929 commit cde0624

2 files changed: +10, -24 lines

src/peft/tuners/lora/dora.py

Lines changed: 6 additions & 21 deletions
@@ -62,12 +62,12 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False):
             weight_norm = weight_norm.to("cpu")
         self.weight = nn.Parameter(weight_norm, requires_grad=True)
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_result = lora_B(lora_A(x))
+        lora_result = lora_B(lora_A(dropout(x)))
 
         # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
         # calculate the same but using forward.
@@ -86,9 +86,7 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
-        result_dora = (mag_norm_scale - 1) * (
-            F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_result * scaling
+        result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling
 
         # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
         # This is only correct if dropout=0, otherwise results will differ:
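
The equivalence the Linear change relies on can be checked in isolation. Below is a minimal, self-contained sketch (not PEFT's actual classes): with the base layer's bias disabled and dropout inactive, the precomputed base_layer_result matches what the old code recomputed via F.linear, so the DoRA combination gives the same output either way. All sizes and names (in_f, out_f, r, magnitude, scaling) are made up for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_f, out_f, r, bsz = 8, 6, 2, 4

base_layer = nn.Linear(in_f, out_f, bias=False)
lora_A = nn.Linear(in_f, r, bias=False)
lora_B = nn.Linear(r, out_f, bias=False)
dropout = nn.Identity()              # stand-in for an inactive lora_dropout
scaling = 0.5
magnitude = torch.ones(out_f)        # stand-in for the DoRA magnitude vector

x = torch.randn(bsz, in_f)
base_layer_result = base_layer(x)    # computed once by the caller and passed in

# What the old code recomputed inside the DoRA forward ...
recomputed = F.linear(x, base_layer.weight)

# ... and the DoRA combination that now reuses the caller's base output instead.
lora_result = lora_B(lora_A(dropout(x)))
weight = base_layer.weight + scaling * lora_B.weight @ lora_A.weight
weight_norm = weight.norm(p=2, dim=1).detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)

old = (mag_norm_scale - 1) * recomputed + mag_norm_scale * lora_result * scaling
new = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling
print(torch.allclose(old, new, atol=1e-6))  # True: no bias, dropout inactive
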
@@ -142,7 +140,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
         return weight_norm
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result=None):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
@@ -161,21 +159,8 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result=None):
         weight_norm = weight_norm.detach()
         mag_norm_scale = magnitude / weight_norm
 
-        if isinstance(base_layer_result, torch.Tensor):
-            # the base layer has already computed the convolution, we do not need to compute it again.
-            result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
-        else:
-            result_dora = (mag_norm_scale - 1) * (
-                F.conv2d(
-                    x,
-                    weight,
-                    bias=None,
-                    stride=base_layer.stride,
-                    padding=base_layer.padding,
-                    dilation=base_layer.dilation,
-                    groups=base_layer.groups,
-                )
-            ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
+        # the base layer has already computed the convolution, we do not need to compute it again.
+        result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling
 
         return result_dora

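On the Conv2d side the isinstance() fallback is removed, so the module now always reuses the convolution output the caller already computed. A hedged toy check (made-up shapes, bias=None, dropout inactive, and mag_norm_scale as a placeholder rather than the real magnitude / weight_norm) that the deleted F.conv2d branch reproduced exactly that tensor:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
base_layer = nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=False)
lora_A = nn.Conv2d(3, 2, kernel_size=3, padding=1, bias=False)
lora_B = nn.Conv2d(2, 5, kernel_size=1, bias=False)
dropout = nn.Identity()
scaling = 1.0

x = torch.randn(2, 3, 8, 8)
base_layer_result = base_layer(x)   # what layer.py now hands over as `result`

# The removed else-branch recomputed the same tensor from the base layer's weight:
recomputed = F.conv2d(
    x,
    base_layer.weight,
    bias=None,
    stride=base_layer.stride,
    padding=base_layer.padding,
    dilation=base_layer.dilation,
    groups=base_layer.groups,
)
print(torch.allclose(base_layer_result, recomputed))  # True when the base layer has no bias

# The DoRA combination then reuses the passed-in result; mag_norm_scale is a
# placeholder stand-in here, not the real magnitude / weight_norm computation.
mag_norm_scale = torch.rand(1, 5, 1, 1) + 0.5
result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling
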
src/peft/tuners/lora/layer.py

Lines changed: 4 additions & 3 deletions
@@ -585,13 +585,14 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
+                        base_layer_result=result,
+                        dropout=dropout
                     )
 
             result = result.to(torch_result_dtype)
@@ -1120,14 +1121,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
-                    x = dropout(x)
                     result = result + self.lora_magnitude_vector[active_adapter](
                         x,
                         lora_A=lora_A,
                         lora_B=lora_B,
                         scaling=scaling,
                         base_layer=self.get_base_layer(),
-                        base_layer_result=result
+                        base_layer_result=result,
+                        dropout=dropout
                     )
 
             result = result.to(torch_result_dtype)
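
Since layer.py adds the returned result_dora on top of result (the same base layer output it just passed in as base_layer_result), the sum collapses to the usual DoRA rescaling mag_norm_scale * (base + scaling * lora). A tiny numeric check of that identity with made-up tensors, not PEFT code:

import torch

torch.manual_seed(0)
base = torch.randn(4, 6)                  # stands in for base_layer_result
lora = torch.randn(4, 6)                  # stands in for lora_B(lora_A(dropout(x)))
scaling = 0.5
mag_norm_scale = torch.rand(1, 6) + 0.5   # stands in for magnitude / weight_norm

result_dora = (mag_norm_scale - 1) * base + mag_norm_scale * lora * scaling
full = base + result_dora
print(torch.allclose(full, mag_norm_scale * (base + scaling * lora)))  # True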
