9 changes: 5 additions & 4 deletions src/peft/tuners/lora/dora.py
@@ -139,7 +139,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
return weight_norm

def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
"""
For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
output.
@@ -157,8 +157,9 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
# during backpropagation"
weight_norm = weight_norm.detach()
mag_norm_scale = magnitude / weight_norm
result_dora = (mag_norm_scale - 1) * (
self.conv_fn(

if base_result is None:
base_result = self.conv_fn(
x,
weight,
bias=None,
stride=base_layer.stride,
padding=base_layer.padding,
dilation=base_layer.dilation,
groups=base_layer.groups,
)
) + mag_norm_scale * lora_B(lora_A(x)) * scaling

result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
return result_dora

def __repr__(self) -> str:
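The rewrite above is algebraically a no-op: DoRA's extra output is `mag_norm_scale * (base + lora) - base`, which expands to `(mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling`. The only behavioral change is that the base convolution can now be supplied by the caller via `base_result` instead of being recomputed inside the DoRA branch. A minimal standalone sketch of that equivalence (illustrative shapes and names, not the PEFT API):

```python
# Sketch: check that the refactored DoRA expression matches the original one
# when the bias-free base convolution output is reused. Shapes, ranks and the
# mag_norm_scale stand-in are made up for illustration only.
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(2, 4, 8, 8)                                # (batch, in_channels, H, W)
base = torch.nn.Conv2d(4, 6, 3, padding=1)
lora_A = torch.nn.Conv2d(4, 2, 3, padding=1, bias=False)   # rank-2 down-projection
lora_B = torch.nn.Conv2d(2, 6, 1, bias=False)              # up-projection
scaling = 0.5
mag_norm_scale = torch.rand(1, 6, 1, 1) + 0.5              # stand-in for magnitude / weight_norm

# Original formulation: recompute the base convolution inside the DoRA branch.
old = (mag_norm_scale - 1) * F.conv2d(
    x, base.weight, bias=None, stride=base.stride, padding=base.padding,
    dilation=base.dilation, groups=base.groups,
) + mag_norm_scale * lora_B(lora_A(x)) * scaling

# Refactored formulation: reuse a base result computed once by the caller.
base_result = F.conv2d(
    x, base.weight, bias=None, stride=base.stride, padding=base.padding,
    dilation=base.dilation, groups=base.groups,
)
new = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling

assert torch.allclose(old, new, atol=1e-6)
```

When the caller can hand over an existing base output, this skips one extra pass through the base layer per DoRA forward.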
9 changes: 8 additions & 1 deletion src/peft/tuners/lora/layer.py
@@ -1272,6 +1272,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)

else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype
@@ -1288,13 +1289,19 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
x = dropout(x)
if isinstance(dropout, nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)

result = result.to(torch_result_dtype)
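On the caller side, `result` (the base layer's output) can double as `base_result` only when dropout is inert; with active dropout the LoRA path must see the dropped-out `x` while `result` was computed from the unmodified input, so `base_result` stays `None` and the DoRA module falls back to recomputing the base output itself. A small sketch of that branching with hypothetical helper names (not part of PEFT):

```python
# Sketch of the gating logic added in layer.py; `pick_base_result` is a
# hypothetical helper, not a PEFT function.
import torch.nn as nn

def pick_base_result(result, x, dropout: nn.Module, training: bool):
    """Return (x_for_lora, base_result_or_None) following the PR's branching."""
    if isinstance(dropout, nn.Identity) or not training:
        # Dropout is a no-op: the already-computed base output can be reused,
        # avoiding a second pass through the base layer inside the DoRA branch.
        return x, result
    # Dropout is active: feed the dropped-out input to the LoRA path and signal
    # the DoRA module (base_result=None) that it must recompute the base output.
    return dropout(x), None
```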