@@ -175,7 +175,7 @@ def construct(

         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask.astype(self.dense_dtype)
+            attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype)

         # Normalize the attention scores to probabilities.
         # Use the trick of the CogView paper to stablize training
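A minimal sketch of what this hunk changes, assuming MindSpore is installed and using made-up shapes and values: the additive mask is now cast to the dtype the scores actually carry at run time, rather than to a fixed `self.dense_dtype` attribute, so the addition stays consistent whether the matmul ran in float16 or float32.

```python
# Minimal sketch (hypothetical shapes/values), not the model's actual construct().
import numpy as np
import mindspore as ms

# Under mixed precision the scores may come out of the matmul in float16 ...
attention_scores = ms.Tensor(np.random.randn(1, 1, 4, 4), ms.float16)
# ... while the precomputed additive mask is float32
# (0.0 keeps a position, a large negative value suppresses it after softmax).
attention_mask = ms.Tensor([[[[0.0, 0.0, -65504.0, -65504.0]]]], ms.float32)

# Cast the mask to whatever dtype the scores have at run time.
attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype)
```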
@@ -227,11 +227,8 @@ def __init__(self, config):
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
         self.patch_size = config.patch_size
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16
-        self.min = finfo(self.dense_dtype)
+        self.float32_min = finfo(mstype.float32)
+        self.float16_min = finfo(mstype.float16)
         self.out_channels = 1
         self.use_visual_backbone = True

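Because the working dtype is no longer fixed at construction time, both minima are precomputed in `__init__`. The `finfo` helper used here is assumed to return the most negative finite value of a float dtype, in the spirit of `numpy.finfo(...).min`; a rough stand-in under that assumption:

```python
# Sketch of the assumed finfo semantics; the real helper is defined elsewhere in the codebase.
import numpy as np
from mindspore import dtype as mstype

_MS_TO_NP = {mstype.float16: np.float16, mstype.float32: np.float32}

def finfo(ms_dtype):
    """Most negative finite value representable by the given float dtype."""
    return float(np.finfo(_MS_TO_NP[ms_dtype]).min)

float32_min = finfo(mstype.float32)  # about -3.4e38
float16_min = finfo(mstype.float16)  # -65504.0
```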
@@ -342,7 +339,13 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely. # fp16 compatibility
         extended_attention_mask = extended_attention_mask.astype(dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * self.min
+
+        if dtype == mstype.float32:
+            minimum = self.float32_min
+        elif dtype == mstype.float16:
+            minimum = self.float16_min
+
+        extended_attention_mask = (1.0 - extended_attention_mask) * minimum
         return extended_attention_mask

     def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False):
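Selecting the minimum per dtype presumably avoids putting the float32 minimum into a float16 mask, where it would overflow to -inf. A standalone NumPy illustration of the `(1.0 - mask) * minimum` trick, with a hypothetical padding mask of 1s for real tokens and 0s for padding:

```python
# Standalone NumPy illustration; shapes and values are made up.
import numpy as np

attention_mask = np.array([[1.0, 1.0, 1.0, 0.0, 0.0]], dtype=np.float32)
minimum = float(np.finfo(np.float32).min)

# Real tokens get 0.0 added to their scores; padded positions get a bias so
# negative that softmax assigns them (numerically) zero probability.
extended = (1.0 - attention_mask) * minimum

scores = np.zeros((1, 5), dtype=np.float32) + extended
probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
# probs -> [[1/3, 1/3, 1/3, 0.0, 0.0]]
```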
@@ -518,7 +521,7 @@ def construct(


 @register_backbone
-def layoutlmv3(use_float16: bool = True, **kwargs):
-    pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+def layoutlmv3(**kwargs):
+    pretrained_config = LayoutLMv3PretrainedConfig()
     model = LayoutLMv3Model(pretrained_config)
     return model
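Hypothetical call site after this change (assuming the `layoutlmv3` factory defined above is in scope): the backbone is built without a `use_float16` flag, and precision instead follows the dtypes flowing through the network, per the hunks above.

```python
# Hypothetical usage; layoutlmv3 is the registered backbone factory defined above.
model = layoutlmv3()
```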