@@ -808,6 +808,13 @@ def recursive_check(batched_object, single_row_object, model_name, key):
808
808
self .assertFalse (
809
809
torch .isinf (single_row_object ).any (), f"Single row output has `inf` in { model_name } for key={ key } "
810
810
)
811
+ a = torch .amax (torch .abs (batched_row ))
812
+ b = torch .amax (torch .abs (single_row_object ))
813
+ if torch .is_floating_point (a ) and torch .is_floating_point (b ):
814
+ if a < 1e-9 or b < 1e-9 :
815
+ raise ValueError ("hello" )
816
+ # breakpoint()
817
+ return
811
818
self .assertTrue (
812
819
(equivalence (batched_row , single_row_object )) <= 1e-03 ,
813
820
msg = (
@@ -819,18 +826,18 @@ def recursive_check(batched_object, single_row_object, model_name, key):
819
826
config , batched_input = self .model_tester .prepare_config_and_inputs_for_common ()
820
827
equivalence = get_tensor_equivalence_function (batched_input )
821
828
822
- set_model_tester_for_less_flaky_test (self )
829
+ # set_model_tester_for_less_flaky_test(self)
823
830
824
831
for model_class in self .all_model_classes :
825
832
config .output_hidden_states = True
826
- set_config_for_less_flaky_test (config )
833
+ # set_config_for_less_flaky_test(config)
827
834
828
835
model_name = model_class .__name__
829
836
if hasattr (self .model_tester , "prepare_config_and_inputs_for_model_class" ):
830
837
config , batched_input = self .model_tester .prepare_config_and_inputs_for_model_class (model_class )
831
838
batched_input_prepared = self ._prepare_for_class (batched_input , model_class )
832
839
model = model_class (config ).to (torch_device ).eval ()
833
- set_model_for_less_flaky_test (model )
840
+ # set_model_for_less_flaky_test(model)
834
841
835
842
batch_size = self .model_tester .batch_size
836
843
single_row_input = {}
0 commit comments