@@ -62,12 +62,12 @@ def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False
             weight_norm = weight_norm.to("cpu")
         self.weight = nn.Parameter(weight_norm, requires_grad=True)
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
         """
-        lora_result = lora_B(lora_A(x))
+        lora_result = lora_B(lora_A(dropout(x)))
 
         # Don't use `lora_weight = lora_B.weight @ lora_A.weight` because this causes errors with FSDP. Instead,
         # calculate the same but using forward.
@@ -86,9 +86,7 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
-        result_dora = (mag_norm_scale - 1) * (
-            F.linear(x, transpose(weight, self.fan_in_fan_out))
-        ) + mag_norm_scale * lora_result * scaling
+        result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling
 
         # Note: Computation could potentially be accelerated by using the code below instead of calculating X@W again.
         # This is only correct if dropout=0, otherwise results will differ:
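To see why the simplification in this hunk is safe, here is a small standalone check (not the PEFT code itself), assuming a bias-free base layer, fan_in_fan_out=False, and an identity dropout: reusing the already computed base_layer_result gives the same result_dora as re-running F.linear, and adding result_dora on top of the base output reproduces the full DoRA output magnitude * x @ (W + scaling * B A)^T / ||W + scaling * B A||.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(4, 16)
    weight = torch.randn(8, 16)                      # base weight, shape (out_features, in_features)
    lora_A = torch.nn.Linear(16, 2, bias=False)
    lora_B = torch.nn.Linear(2, 8, bias=False)
    scaling = 0.5
    magnitude = torch.rand(8) + 0.1                  # stands in for the learned DoRA magnitude vector

    dora_weight = weight + scaling * lora_B.weight @ lora_A.weight
    weight_norm = dora_weight.norm(p=2, dim=1)       # column-wise (per output feature) L2 norm
    mag_norm_scale = (magnitude / weight_norm).view(1, -1)

    base_layer_result = F.linear(x, weight)          # computed once by the wrapping LoRA layer
    lora_result = lora_B(lora_A(x))                  # dropout assumed to be nn.Identity here

    # new code path: reuse the base output instead of recomputing F.linear(x, W)
    result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_result * scaling

    # adding result_dora to the base output yields the full DoRA result
    full_dora = F.linear(x, magnitude.view(-1, 1) * dora_weight / weight_norm.view(-1, 1))
    assert torch.allclose(base_layer_result + result_dora, full_dora, atol=1e-5)

The design choice shown by the diff is that X @ W is no longer computed a second time inside the DoRA helper, and dropout is applied only to the LoRA branch via the module that is now passed in.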
@@ -142,7 +140,7 @@ def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
         weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
         return weight_norm
 
-    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result=None):
+    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result, dropout):
         """
         For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
         output.
@@ -161,21 +159,8 @@ def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_layer_result=None
         weight_norm = weight_norm.detach()
         mag_norm_scale = magnitude / weight_norm
 
-        if isinstance(base_layer_result, torch.Tensor):
-            # the base layer has already computed the convolution, we do not need to compute it again.
-            result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(x)) * scaling
-        else:
-            result_dora = (mag_norm_scale - 1) * (
-                F.conv2d(
-                    x,
-                    weight,
-                    bias=None,
-                    stride=base_layer.stride,
-                    padding=base_layer.padding,
-                    dilation=base_layer.dilation,
-                    groups=base_layer.groups,
-                )
-            ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
+        # the base layer has already computed the convolution, we do not need to compute it again.
+        result_dora = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling
 
         return result_dora
 
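The Conv2d hunk removes the F.conv2d fallback entirely, so the caller is now expected to always pass the pre-computed convolution output along with the dropout module. A minimal standalone sketch of the equivalence (not the PEFT implementation), assuming a bias-free base convolution and an identity dropout; with an active dropout the two branches deliberately differ, which is why dropout is now routed through the helper:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 3, 8, 8)
    base_layer = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
    lora_A = nn.Conv2d(3, 2, kernel_size=3, padding=1, bias=False)
    lora_B = nn.Conv2d(2, 4, kernel_size=1, bias=False)
    dropout = nn.Identity()                    # assumed; nn.Dropout would make the two branches differ
    scaling = 2.0
    mag_norm_scale = torch.rand(1, 4, 1, 1)    # stands in for magnitude / weight_norm

    base_layer_result = base_layer(x)          # computed once by the wrapping LoRA layer

    # removed branch: recompute the base convolution inside the DoRA helper
    old = (mag_norm_scale - 1) * F.conv2d(
        x, base_layer.weight, bias=None,
        stride=base_layer.stride, padding=base_layer.padding,
        dilation=base_layer.dilation, groups=base_layer.groups,
    ) + mag_norm_scale * lora_B(lora_A(x)) * scaling

    # kept branch: reuse base_layer_result and route x through the passed-in dropout
    new = (mag_norm_scale - 1) * base_layer_result + mag_norm_scale * lora_B(lora_A(dropout(x))) * scaling

    assert torch.allclose(old, new)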