@@ -62,13 +62,11 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False):
             weight_norm = weight_norm.to("cpu")
         self.weight = nn.Parameter(weight_norm, requires_grad=True)
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_result = lora_B(lora_A(x))
-
         # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
         # calculate the same but using forward.
         x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
@@ -86,19 +84,18 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
-        result_dora = (mag_norm_scale - 1) * (
-            F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_result * scaling
-
-        # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
-        # This is only correct if dropout=0, otherwise results will differ:
-        # https://github.com/huggingface/peft/pull/1474#issuecomment-1964682771
-        # bias = self.get_base_layer().bias
-        # if bias is not None:
-        #     result = result - bias
-        # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling
-        # if bias is not None:
-        #     result = result + bias
+
+        lora_result = lora_B(lora_A(x))
+
+        bias = None
+        if base_result is not None:
+            bias = base_layer.bias
+            if bias is not None:
+                base_result = base_result - bias
+        else:
+            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))
+
+        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
 
         return result_dora
 
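
The new `base_result` keyword lets the caller hand over the already computed base layer output, so the DoRA path no longer recomputes F.linear(x, W); when a bias is present it is subtracted first, since `result_dora` must scale only the bias-free linear term. As the removed comment notes, this shortcut is only exact when dropout is 0. Below is a minimal, illustrative sketch (not the actual peft caller) of how a wrapping Linear layer might pass `base_result`; the names BaseLinearWithDora and dora_layer are assumptions for the example.

import torch.nn as nn

class BaseLinearWithDora(nn.Module):
    def __init__(self, base_layer, lora_A, lora_B, dora_layer, scaling=1.0):
        super().__init__()
        self.base_layer = base_layer  # frozen nn.Linear
        self.lora_A = lora_A          # nn.Linear(in_features, r, bias=False)
        self.lora_B = lora_B          # nn.Linear(r, out_features, bias=False)
        self.dora_layer = dora_layer  # object exposing forward(...) as in the diff above
        self.scaling = scaling

    def forward(self, x):
        # Compute the base output once; it includes the bias, which the DoRA
        # helper subtracts internally before rescaling.
        base_result = self.base_layer(x)
        # Passing base_result skips the extra F.linear(x, W) inside the helper.
        # Per the removed comment, this is only exact when dropout=0.
        dora_out = self.dora_layer.forward(
            x,
            lora_A=self.lora_A,
            lora_B=self.lora_B,
            scaling=self.scaling,
            base_layer=self.base_layer,
            base_result=base_result,
        )
        return base_result + dora_out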