Commit 1f88f40: check

1 parent 223bfe2 commit 1f88f40

2 files changed: +11 -4 lines changed

.circleci/create_circleci_config.py

+1 -1

@@ -154,7 +154,7 @@ def to_dict(self):
             },
             {"run": {
                 "name": "Run tests",
-                "command": f"({timeout_cmd} python3 -m pytest {marker_cmd} -n 1 {additional_flags} {' '.join(pytest_flags)} tests/models/flaubert/test_modeling_flaubert.py::FlaubertModelTest::test_batching_equivalence | tee tests_output.txt)"}
+                "command": f"({timeout_cmd} python3 -m pytest {marker_cmd} -n 1 {additional_flags} {' '.join(pytest_flags)} tests/models -k test_batching_equivalence | tee tests_output.txt)"}
             },
             {"run": {"name": "Expand to show skipped tests", "when": "always", "command": f"python3 .circleci/parse_test_outputs.py --file tests_output.txt --skip"}},
             {"run": {"name": "Failed tests: show reasons", "when": "always", "command": f"python3 .circleci/parse_test_outputs.py --file tests_output.txt --fail"}},

tests/test_modeling_common.py

+10 -3

@@ -808,6 +808,13 @@ def recursive_check(batched_object, single_row_object, model_name, key):
                 self.assertFalse(
                     torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
                 )
+                a = torch.amax(torch.abs(batched_row))
+                b = torch.amax(torch.abs(single_row_object))
+                if torch.is_floating_point(a) and torch.is_floating_point(b):
+                    if a < 1e-9 or b < 1e-9:
+                        raise ValueError("hello")
+                # breakpoint()
+                return
                 self.assertTrue(
                     (equivalence(batched_row, single_row_object)) <= 1e-03,
                     msg=(
@@ -819,18 +826,18 @@ def recursive_check(batched_object, single_row_object, model_name, key):
         config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
         equivalence = get_tensor_equivalence_function(batched_input)
 
-        set_model_tester_for_less_flaky_test(self)
+        #set_model_tester_for_less_flaky_test(self)
 
         for model_class in self.all_model_classes:
             config.output_hidden_states = True
-            set_config_for_less_flaky_test(config)
+            #set_config_for_less_flaky_test(config)
 
             model_name = model_class.__name__
             if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
                 config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
             batched_input_prepared = self._prepare_for_class(batched_input, model_class)
             model = model_class(config).to(torch_device).eval()
-            set_model_for_less_flaky_test(model)
+            #set_model_for_less_flaky_test(model)
 
             batch_size = self.model_tester.batch_size
             single_row_input = {}
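The hunk added to recursive_check computes the largest absolute value of the batched and single-row outputs, raises when either falls below 1e-9, and then returns before the equivalence assertion runs; together with the ValueError("hello"), the commented-out breakpoint(), and the disabled set_*_for_less_flaky_test calls, it reads as a temporary debugging probe matching the commit message "check". A standalone sketch of the underlying near-zero guard, assuming PyTorch; the helper name has_near_zero_output and the eps parameter are illustrative, not part of the commit:

import torch

def has_near_zero_output(batched: torch.Tensor, single_row: torch.Tensor, eps: float = 1e-9) -> bool:
    # Largest absolute value in each output tensor.
    a = torch.amax(torch.abs(batched))
    b = torch.amax(torch.abs(single_row))
    # amax of an integer tensor is still integer-typed, so guard on dtype the
    # same way the diff does before comparing against a float threshold.
    if torch.is_floating_point(a) and torch.is_floating_point(b):
        # Mirrors the diff's `a < 1e-9 or b < 1e-9`: if either output has no
        # magnitude above eps, a relative-difference comparison such as
        # `equivalence(batched_row, single_row_object) <= 1e-03` is dominated
        # by floating-point noise rather than by real batching differences.
        return bool(a < eps or b < eps)
    return False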
