Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Removed duplicate convolution for DoRA #2153

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions src/peft/tuners/lora/dora.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=Fals
weight_norm = weight_norm.to("cpu")
self.weight = nn.Parameter(weight_norm, requires_grad=True)

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_result = lora_B(lora_A(x))
lora_result = lora_B(lora_A(dropout(x)))

# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
Expand All @@ -86,9 +86,7 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = (mag_norm_scale - 1) * (
F.linear(x, transpose(weight, self.fan_in_fan_out))
) + mag_norm_scale * lora_result * scaling
result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling

# Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
# This is only correct if dropout=0, otherwise results will differ:
Expand Down Expand Up @@ -142,7 +140,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
return weight_norm

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
Expand All @@ -160,17 +158,9 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = (mag_norm_scale - 1) * (
self.conv_fn(
x,
weight,
bias=None,
stride=base_layer.stride,
padding=base_layer.padding,
dilation=base_layer.dilation,
groups=base_layer.groups,
)
) + mag_norm_scale * lora_B(lora_A(x)) * scaling

# the base layer has already computed the convolution, we do not need to compute it again.
result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling

return result_dora

Expand Down
8 changes: 5 additions & 3 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,14 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_layer_result=result,
dropout=dropout
)

result = result.to(torch_result_dtype)
Expand Down Expand Up @@ -904,6 +905,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
out_kernel = out_stride = (1,) * (self._kernel_dim - 2)
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=False)

if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
Expand Down Expand Up @@ -1088,7 +1090,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)

if self.disable_adapters:
if self.merged:
self.unmerge()
Expand All @@ -1113,13 +1114,14 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_layer_result=result,
dropout=dropout
)

result = result.to(torch_result_dtype)
Expand Down