Optimize DoRA in eval and no dropout #2122

Merged: 17 commits merged on Oct 16, 2024
34 changes: 18 additions & 16 deletions src/peft/tuners/lora/dora.py
@@ -62,13 +62,11 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False):
weight_norm = weight_norm.to("cpu")
self.weight = nn.Parameter(weight_norm, requires_grad=True)

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
"""
lora_result = lora_B(lora_A(x))

# Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
# calculate the same but using forward.
x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
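
For illustration, here is a minimal self-contained sketch (made-up dimensions, bias-free LoRA layers as PEFT creates them) of why the identity-matrix trick above reproduces lora_B.weight @ lora_A.weight while only going through forward calls:

    import torch
    import torch.nn as nn

    # Hypothetical sizes: in_features=8, rank=4, out_features=16.
    lora_A = nn.Linear(8, 4, bias=False)
    lora_B = nn.Linear(4, 16, bias=False)

    # lora_A(I) = I @ A.T = A.T, and lora_B(A.T) = A.T @ B.T = (B @ A).T,
    # so transposing the result gives B @ A without touching .weight directly.
    x_eye = torch.eye(lora_A.weight.shape[1])
    lora_weight = lora_B(lora_A(x_eye)).T

    assert torch.allclose(lora_weight, lora_B.weight @ lora_A.weight, atol=1e-6)
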
@@ -86,19 +84,23 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = (mag_norm_scale - 1) * (
F.linear(x, transpose(weight, self.fan_in_fan_out))
) + mag_norm_scale * lora_result * scaling

# Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
# This is only correct if dropout=0, otherwise results will differ:
# https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771
# bias = self.get_base_layer().bias
# if bias is not None:
# result = result - bias
# result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling
# if bias is not None:
# result = result + bias

lora_result = lora_B(lora_A(x))

if base_result is not None:
# `base_result` is provided only if dropout is set to 0 or if the model is in evaluation mode.
# This means we already have a deterministic output from the base layer that can be reused.
bias = base_layer.bias
if bias is not None:
base_result = base_result - bias
result_dora = mag_norm_scale * base_result + mag_norm_scale * lora_result * scaling
BenjaminBossan (Member):

Hmm, wait, should this not be the exact same calculation as in line 103? I.e. we should leave the condition after calculating the base_result and then do the same calculation of dora_result for both cases.

ariG23498 (Contributor Author):

I am not sure that I follow. With the base_result in place:

  1. We first subtract the bias
  2. Compute the dora_result, where we scale the base_result with mag_norm_scale

But without the base_result:

  1. We compute the base_result with the linear forward
  2. Compute the dora_result, where we scale the base_result with (mag_norm_scale - 1)

Aren't they going to be different for each case?

BenjaminBossan (Member):

Okay, so I'm a bit confused, let's try to resolve this.

In the old code, we basically have:

    dora_result = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * lora_scale

(variable names slightly changed for clarity)

My thinking is that the base_result is either calculated right there (old code) or we use the base_result that is being passed as an argument, but the basic equation stays the same.

Of course, as you correctly noted, the bias needs to be subtracted first and then added back in the latter case.

In the currently proposed code, in one case we calculate mag_norm_scale * base_result and in the other (mag_norm_scale - 1) * base_result. This looks inconsistent to me.
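
As a quick numeric check of the algebra in question (random stand-in tensors, no real layers, assuming the bias has already been handled): adding (mag_norm_scale - 1) * base + mag_norm_scale * lora * scaling on top of the base output gives the same total as using mag_norm_scale * base + mag_norm_scale * lora * scaling as the full output, so the two forms differ only in whether the caller adds the base output back:

    import torch

    torch.manual_seed(0)
    base = torch.randn(2, 16)           # stand-in for x @ W.T (no bias)
    lora = torch.randn(2, 16)           # stand-in for lora_B(lora_A(x))
    scaling = 0.5
    mag_norm_scale = torch.rand(1, 16)  # stand-in for magnitude / weight_norm

    # "extra output" form: meant to be added on top of the base output
    extra = (mag_norm_scale - 1) * base + mag_norm_scale * lora * scaling
    # "full output" form: meant to replace the base output
    full = mag_norm_scale * base + mag_norm_scale * lora * scaling

    assert torch.allclose(base + extra, full, atol=1e-6)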

ariG23498 (Contributor Author):

Hi @BenjaminBossan

I have made the changes as suggested.

BenjaminBossan (Member):

I think the change is still not correct. Here is my suggestion:

        bias = None
        if base_result is not None:
            bias = base_layer.bias
            if bias is not None:
                base_result = base_result - bias
        else:
            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))

        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling

        if bias is not None:
            result_dora = result_dora + bias

This way, if base_result = None, the computation is exactly the same as it was previously.

I believe the confusion may stem from my comment:

        # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling

This comment should have been:

        # result = (mag_norm_scale - 1) * result + mag_norm_scale * lora_B(lora_A(x)) * scaling

Does that make sense?

ariG23498 (Contributor Author):

That makes sense!

About the unit test -- do you want me to add a new test file? Or add a test somewhere?

if bias is not None:
result_dora = result_dora + bias
else:
# If `base_result` is not provided (likely due to dropout being used or training mode),
# calculate it directly using the base layer weights.
base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling

return result_dora
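
As a self-contained sanity check of the two branches above (plain nn.Linear base layer, no dropout, no quantization, made-up mag_norm_scale), mirroring how the call sites in layer.py below use them -- the reuse branch replaces the layer output, the recompute branch is added on top of it:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    base_layer = nn.Linear(8, 16)        # base layer with a bias
    lora_A = nn.Linear(8, 4, bias=False)
    lora_B = nn.Linear(4, 16, bias=False)
    scaling = 0.5
    mag_norm_scale = torch.rand(1, 16)   # stand-in for magnitude / weight_norm

    x = torch.randn(2, 8)
    base_result = base_layer(x)          # what the LoRA layer has already computed
    lora_result = lora_B(lora_A(x))

    # Reuse branch: subtract the bias, scale, re-add the bias; used as the full output.
    bias = base_layer.bias
    no_bias = base_result - bias
    out_reuse = mag_norm_scale * no_bias + mag_norm_scale * lora_result * scaling + bias

    # Recompute branch: x @ W.T without bias; the caller adds this on top of base_result.
    recomputed = F.linear(x, base_layer.weight)
    extra = (mag_norm_scale - 1) * recomputed + mag_norm_scale * lora_result * scaling
    out_recompute = base_result + extra

    assert torch.allclose(out_reuse, out_recompute, atol=1e-5)

Once dropout actually drops values, the x fed to the LoRA path no longer matches the x the base layer saw, which is why this shortcut is restricted to the no-dropout / eval case.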

27 changes: 19 additions & 8 deletions src/peft/tuners/lora/layer.py
@@ -585,14 +585,25 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
)
if isinstance(dropout, nn.Identity) or not self.training:
result = self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=result,
)
else:
x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=None,
)

result = result.to(torch_result_dtype)
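
For a usage-level illustration of when the new base_result path is taken (hypothetical toy model; assumes a PEFT version with DoRA support): lora_dropout=0.0 gives an nn.Identity dropout and model.eval() clears self.training, and either condition routes the forward pass through the branch that reuses result:

    import torch
    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)

        def forward(self, x):
            return self.linear(x)

    # lora_dropout=0.0 keeps dropout as nn.Identity, so the reuse path is taken
    # even during training; model.eval() would trigger it as well via self.training.
    config = LoraConfig(target_modules=["linear"], use_dora=True, lora_dropout=0.0)
    model = get_peft_model(TinyModel(), config)

    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 8))
    print(out.shape)  # torch.Size([2, 8])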
