@@ -20,19 +20,28 @@ class ExampleDocument(TextDocument):
2020
2121@pytest .mark .slow
2222@pytest .mark .parametrize ("use_auto" , [False , True ])
23- def test_re_text_classification (use_auto ):
23+ @pytest .mark .parametrize ("half_precision_model" , [False , True ])
24+ @pytest .mark .parametrize ("half_precision_ops" , [False , True ])
25+ def test_re_text_classification (use_auto , half_precision_model , half_precision_ops ):
2426 model_name_or_path = "pie/example-re-textclf-tacred"
2527 if use_auto :
2628 pipeline = AutoPipeline .from_pretrained (
27- model_name_or_path , taskmodule_kwargs = {"create_relation_candidates" : True }
29+ model_name_or_path ,
30+ taskmodule_kwargs = {"create_relation_candidates" : True },
31+ half_precision_model = half_precision_model ,
2832 )
2933 else :
3034 re_taskmodule = TransformerRETextClassificationTaskModule .from_pretrained (
3135 model_name_or_path ,
3236 create_relation_candidates = True ,
3337 )
3438 re_model = TransformerTextClassificationModel .from_pretrained (model_name_or_path )
35- pipeline = Pipeline (model = re_model , taskmodule = re_taskmodule , device = - 1 )
39+ pipeline = Pipeline (
40+ model = re_model ,
41+ taskmodule = re_taskmodule ,
42+ device = - 1 ,
43+ half_precision_model = half_precision_model ,
44+ )
3645 assert pipeline .taskmodule .is_from_pretrained
3746 assert pipeline .model .is_from_pretrained
3847
@@ -44,7 +53,7 @@ def test_re_text_classification(use_auto):
4453 for start , end , label in [(65 , 75 , "PER" ), (96 , 100 , "ORG" ), (126 , 134 , "ORG" )]:
4554 document .entities .append (LabeledSpan (start = start , end = end , label = label ))
4655
47- pipeline (document , batch_size = 2 )
56+ pipeline (document , batch_size = 2 , half_precision_ops = half_precision_ops )
4857 relations : Sequence [BinaryRelation ] = document ["relations" ].predictions
4958 assert len (relations ) == 3
5059
0 commit comments