@@ -370,7 +370,7 @@ def test_batching_equivalence(self, atol=1e-5, rtol=1e-5):
370370 LayoutLMv3's relative position bias is incompatible with SDPA/FlashAttention.
371371 """
372372 from transformers .testing_utils import set_config_for_less_flaky_test
373-
373+
374374 def recursive_check (batched_object , single_row_object , model_name , key ):
375375 if isinstance (batched_object , (list , tuple )):
376376 for batched_object_value , single_row_object_value in zip (batched_object , single_row_object ):
@@ -407,12 +407,12 @@ def recursive_check(batched_object, single_row_object, model_name, key):
407407 msg = f"Batched and Single row outputs are not equal in { model_name } for key={ key } .\n \n "
408408 msg += str (e )
409409 raise AssertionError (msg )
410-
410+
411411 config , batched_input = self .model_tester .prepare_config_and_inputs_for_common ()
412412 set_config_for_less_flaky_test (config )
413413 # Force eager attention since LayoutLMv3 has relative position bias enabled by default
414414 config ._attn_implementation = "eager"
415-
415+
416416 for model_class in self .all_model_classes :
417417 config .output_hidden_states = True
418418 model_name = model_class .__name__
@@ -423,8 +423,9 @@ def recursive_check(batched_object, single_row_object, model_name, key):
423423 batched_input_prepared = self ._prepare_for_class (batched_input , model_class )
424424 model = model_class (copy .deepcopy (config )).to (torch_device ).eval ()
425425 from transformers .testing_utils import set_model_for_less_flaky_test
426+
426427 set_model_for_less_flaky_test (model )
427-
428+
428429 batch_size = self .model_tester .batch_size
429430 single_row_input = {}
430431 for key , value in batched_input_prepared .items ():
@@ -433,15 +434,15 @@ def recursive_check(batched_object, single_row_object, model_name, key):
433434 single_row_input [key ] = value [:single_batch_shape ]
434435 else :
435436 single_row_input [key ] = value
436-
437+
437438 with torch .no_grad ():
438439 model_batched_output = model (** batched_input_prepared )
439440 model_row_output = model (** single_row_input )
440-
441+
441442 if isinstance (model_batched_output , torch .Tensor ):
442443 model_batched_output = {"model_output" : model_batched_output }
443444 model_row_output = {"model_output" : model_row_output }
444-
445+
445446 recursive_check (model_batched_output , model_row_output , model_name , "" )
446447
447448 @slow
0 commit comments