11from dataclasses import dataclass
2- from typing import Sequence
32
43import pytest
4+ import torch
55from pie_core import AnnotationLayer , annotation_field
66
77from pytorch_ie import AutoPipeline
1111from pytorch_ie .pipeline import Pipeline
1212from pytorch_ie .taskmodules import TransformerRETextClassificationTaskModule
1313
# Force torch to use deterministic algorithms so the exact prediction scores
# asserted in the tests below are reproducible across runs.
# NOTE(review): this is a module-level side effect that applies to every test
# in this file (and any test run after importing it) — confirm that is intended.
14+ torch .use_deterministic_algorithms (True )
15+
1416
1517@dataclass
1618class ExampleDocument (TextDocument ):
@@ -22,7 +24,9 @@ class ExampleDocument(TextDocument):
2224@pytest .mark .parametrize ("use_auto" , [False , True ])
2325@pytest .mark .parametrize ("half_precision_model" , [False , True ])
2426@pytest .mark .parametrize ("half_precision_ops" , [False , True ])
25- def test_re_text_classification (use_auto , half_precision_model , half_precision_ops ):
27+ def test_re_text_classification (use_auto , half_precision_model , half_precision_ops , caplog ):
28+
29+ # set up the pipeline
2630 model_name_or_path = "pie/example-re-textclf-tacred"
2731 if use_auto :
2832 pipeline = AutoPipeline .from_pretrained (
@@ -45,39 +49,53 @@ def test_re_text_classification(use_auto, half_precision_model, half_precision_o
4549 assert pipeline .taskmodule .is_from_pretrained
4650 assert pipeline .model .is_from_pretrained
4751
52+ # create a document with entities
4853 document = ExampleDocument (
4954 "“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner "
5055 "at SOSV and managing director of IndieBio."
5156 )
57+ document .entities .append (LabeledSpan (start = 65 , end = 75 , label = "PER" ))
58+ document .entities .append (LabeledSpan (start = 96 , end = 100 , label = "ORG" ))
59+ document .entities .append (LabeledSpan (start = 126 , end = 134 , label = "ORG" ))
5260
53- for start , end , label in [(65 , 75 , "PER" ), (96 , 100 , "ORG" ), (126 , 134 , "ORG" )]:
54- document .entities .append (LabeledSpan (start = start , end = end , label = label ))
61+ # predict relations
62+ with caplog .at_level ("WARNING" ):
63+ pipeline (document , batch_size = 2 , half_precision_ops = half_precision_ops )
5564
56- pipeline (document , batch_size = 2 , half_precision_ops = half_precision_ops )
57- relations : Sequence [BinaryRelation ] = document ["relations" ].predictions
58- assert len (relations ) == 3
65+ # sort to get deterministic order
66+ sorted_relations = sorted (document .relations .predictions )
5967
60- rels = sorted (relations , key = lambda rel : (rel .head .start + rel .tail .start ) / 2 )
61-
62- # Note: The scores are quite low, because the model is trained with the old version for the taskmodule,
63- # so the argument markers are not correct.
64- assert (str (rels [0 ].head ), rels [0 ].label , str (rels [0 ].tail )) == (
65- "SOSV" ,
66- "org:top_members/employees" ,
67- "Po Bronson" ,
68- )
69- assert rels [0 ].score == pytest .approx (0.398 , abs = 1e-2 )
68+ # check the relations and their scores
69+ assert [ann .resolve () for ann in sorted_relations ] == [
70+ ("per:employee_of" , (("PER" , "Po Bronson" ), ("ORG" , "IndieBio" ))),
71+ ("org:top_members/employees" , (("ORG" , "SOSV" ), ("PER" , "Po Bronson" ))),
72+ ("org:top_members/employees" , (("ORG" , "IndieBio" ), ("PER" , "Po Bronson" ))),
73+ ]
7074
71- assert (str (rels [1 ].head ), rels [1 ].label , str (rels [1 ].tail )) == (
72- "Po Bronson" ,
73- "per:employee_of" ,
74- "IndieBio" ,
75+ half_precision_warning = (
76+ "Using half precision operations with a model already in half precision. "
77+ "This is not recommended, as it may lead to unexpected results."
7578 )
76- assert rels [1 ].score == pytest .approx (0.534 , abs = 1e-2 )
7779
78- assert (str (rels [2 ].head ), rels [2 ].label , str (rels [2 ].tail )) == (
79- "IndieBio" ,
80- "org:top_members/employees" ,
81- "Po Bronson" ,
82- )
83- assert rels [2 ].score == pytest .approx (0.552 , abs = 1e-2 )
80+ scores = [rel .score for rel in sorted_relations ]
81+ # General note: The scores are quite low, because the model is trained with the old version
82+ # for the taskmodule, so the argument markers are not correct.
83+ # Below scores were obtained with dependencies from poetry.lock on local machine.
84+ if not half_precision_model and not half_precision_ops :
85+ # we use a low tolerance if no half precision is used
86+ # (i.e., no autocast on forward pass and model is not cast to half precision)
87+ assert scores == pytest .approx (
88+ [0.5339038372039795 , 0.3984701931476593 , 0.5520647764205933 ], abs = 1e-6
89+ )
90+ assert half_precision_warning not in caplog .messages
91+ elif not half_precision_model and half_precision_ops :
92+ # set high tolerance for half precision ops (i.e., autocast on forward pass)
93+ assert scores == pytest .approx ([0.53125 , 0.39453125 , 0.5546875 ], abs = 1e-2 )
94+ assert half_precision_warning not in caplog .messages
95+ elif half_precision_model and not half_precision_ops :
96+ # set high tolerance for half precision model (i.e., model cast to half precision)
97+ assert scores == pytest .approx ([0.53515625 , 0.400390625 , 0.55859375 ], abs = 1e-2 )
98+ assert half_precision_warning not in caplog .messages
99+ else :
100+ # just check that we got the warning about half precision ops in combination with half precision model
101+ assert half_precision_warning in caplog .messages