Skip to content

Commit 7e11234

Browse files
committed
add test cases for parameters half_precision_model and half_precision_ops to test_re_text_classification
1 parent 4b78a86 commit 7e11234

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

tests/pipeline/test_re_text_classification.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,28 @@ class ExampleDocument(TextDocument):
2020

2121
@pytest.mark.slow
2222
@pytest.mark.parametrize("use_auto", [False, True])
23-
def test_re_text_classification(use_auto):
23+
@pytest.mark.parametrize("half_precision_model", [False, True])
24+
@pytest.mark.parametrize("half_precision_ops", [False, True])
25+
def test_re_text_classification(use_auto, half_precision_model, half_precision_ops):
2426
model_name_or_path = "pie/example-re-textclf-tacred"
2527
if use_auto:
2628
pipeline = AutoPipeline.from_pretrained(
27-
model_name_or_path, taskmodule_kwargs={"create_relation_candidates": True}
29+
model_name_or_path,
30+
taskmodule_kwargs={"create_relation_candidates": True},
31+
half_precision_model=half_precision_model,
2832
)
2933
else:
3034
re_taskmodule = TransformerRETextClassificationTaskModule.from_pretrained(
3135
model_name_or_path,
3236
create_relation_candidates=True,
3337
)
3438
re_model = TransformerTextClassificationModel.from_pretrained(model_name_or_path)
35-
pipeline = Pipeline(model=re_model, taskmodule=re_taskmodule, device=-1)
39+
pipeline = Pipeline(
40+
model=re_model,
41+
taskmodule=re_taskmodule,
42+
device=-1,
43+
half_precision_model=half_precision_model,
44+
)
3645
assert pipeline.taskmodule.is_from_pretrained
3746
assert pipeline.model.is_from_pretrained
3847

@@ -44,7 +53,7 @@ def test_re_text_classification(use_auto):
4453
for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
4554
document.entities.append(LabeledSpan(start=start, end=end, label=label))
4655

47-
pipeline(document, batch_size=2)
56+
pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
4857
relations: Sequence[BinaryRelation] = document["relations"].predictions
4958
assert len(relations) == 3
5059

0 commit comments

Comments
 (0)