Commit 39f0a31

Fix code quality: remove whitespace from blank lines and format test file
1 parent f203f87 commit 39f0a31

1 file changed: +8 −7 lines changed

tests/models/layoutlmv3/test_modeling_layoutlmv3.py

Lines changed: 8 additions & 7 deletions
@@ -370,7 +370,7 @@ def test_batching_equivalence(self, atol=1e-5, rtol=1e-5):
         LayoutLMv3's relative position bias is incompatible with SDPA/FlashAttention.
         """
         from transformers.testing_utils import set_config_for_less_flaky_test
-
+
         def recursive_check(batched_object, single_row_object, model_name, key):
             if isinstance(batched_object, (list, tuple)):
                 for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
@@ -407,12 +407,12 @@ def recursive_check(batched_object, single_row_object, model_name, key):
                 msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n"
                 msg += str(e)
                 raise AssertionError(msg)
-
+
         config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
         set_config_for_less_flaky_test(config)
         # Force eager attention since LayoutLMv3 has relative position bias enabled by default
         config._attn_implementation = "eager"
-
+
         for model_class in self.all_model_classes:
             config.output_hidden_states = True
             model_name = model_class.__name__
@@ -423,8 +423,9 @@ def recursive_check(batched_object, single_row_object, model_name, key):
             batched_input_prepared = self._prepare_for_class(batched_input, model_class)
             model = model_class(copy.deepcopy(config)).to(torch_device).eval()
             from transformers.testing_utils import set_model_for_less_flaky_test
+
             set_model_for_less_flaky_test(model)
-
+
             batch_size = self.model_tester.batch_size
             single_row_input = {}
             for key, value in batched_input_prepared.items():
@@ -433,15 +434,15 @@ def recursive_check(batched_object, single_row_object, model_name, key):
                     single_row_input[key] = value[:single_batch_shape]
                 else:
                     single_row_input[key] = value
-
+
             with torch.no_grad():
                 model_batched_output = model(**batched_input_prepared)
                 model_row_output = model(**single_row_input)
-
+
             if isinstance(model_batched_output, torch.Tensor):
                 model_batched_output = {"model_output": model_batched_output}
                 model_row_output = {"model_output": model_row_output}
-
+
             recursive_check(model_batched_output, model_row_output, model_name, "")

     @slow
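
The diff above is pure whitespace cleanup, but the test it touches is worth a gloss: it runs the model once on a full batch and once on a single-row slice of the same inputs, then recursively checks that the outputs agree. Below is a minimal sketch of that pattern outside the transformers test harness; the helper name check_batching_equivalence is hypothetical, and it assumes a model whose forward pass returns a single tensor rather than a ModelOutput dict.

    import torch

    def check_batching_equivalence(model, batched_inputs, atol=1e-5, rtol=1e-5):
        # Slice every batched tensor down to its first row; pass
        # non-tensor values (flags, None, etc.) through unchanged.
        single_row_inputs = {
            key: value[:1] if isinstance(value, torch.Tensor) else value
            for key, value in batched_inputs.items()
        }

        model.eval()
        with torch.no_grad():
            batched_output = model(**batched_inputs)
            single_row_output = model(**single_row_inputs)

        # Row 0 of the batched output should match the batch-of-one output
        # within tolerance; drift beyond atol/rtol usually points at
        # batch-size-dependent behavior, e.g. in the attention kernels.
        torch.testing.assert_close(batched_output[:1], single_row_output, atol=atol, rtol=rtol)

As the comment in the diff notes, the real test first sets config._attn_implementation = "eager", since LayoutLMv3's relative position bias is incompatible with the SDPA/FlashAttention paths.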
