@@ -139,7 +139,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
         return weight_norm

-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
@@ -157,8 +157,9 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = magnitude / weight_norm
-        result_dora = (mag_norm_scale - 1) * (
-            self.conv_fn(
+
+        if base_result is None:
+            base_result = self.conv_fn(
                 x,
                 weight,
                 bias=None,
@@ -167,8 +168,14 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
                 dilation=base_layer.dilation,
                 groups=base_layer.groups,
             )
-        ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
+        else:
+            bias = base_layer.bias
+            if bias is not None:
+                # reshape bias to (1, -1, 1, ...)
+                bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
+                base_result = base_result - bias.view(*bias_shape)

+        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
         return result_dora

     def __repr__(self) -> str:
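
The new `base_result` argument lets the caller hand DoRA the output it has already computed from the base conv layer, so the convolution is not run a second time. That cached output includes the layer's bias, while the `(mag_norm_scale - 1)` term must scale only the weight contribution, so the bias is subtracted before use. A minimal, self-contained sketch (not the PEFT code itself) of the equivalence the new else branch relies on:

# Minimal sketch (not the PEFT implementation): check that a cached base_result
# with its bias removed equals the bias-free convolution DoRA would otherwise recompute.
import torch
import torch.nn as nn
import torch.nn.functional as F

base_layer = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=True)
x = torch.randn(2, 4, 16, 16)

# Branch 1: base_result is None -> recompute the convolution without bias.
no_bias = F.conv2d(
    x,
    base_layer.weight,
    bias=None,
    stride=base_layer.stride,
    padding=base_layer.padding,
    dilation=base_layer.dilation,
    groups=base_layer.groups,
)

# Branch 2: reuse the cached base output and strip the bias, as in the new else branch.
base_result = base_layer(x)
bias_shape = (1, -1) + (1,) * (base_result.dim() - 2)
reused = base_result - base_layer.bias.view(*bias_shape)

torch.testing.assert_close(no_bias, reused)

Either branch therefore feeds the same bias-free tensor into result_dora; when base_result is supplied, the redundant convolution is skipped.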