Skip to content

Commit 6d5a064

Browse files
authored
Merge pull request #477 from ArneBinder/fix/half_precision_model_tests
Fix `test_annotation_pipeline` fails with half-precision-model=True - adjust `test_annotation_pipeline`: - streamline the test (e.g. use `resolve()` etc.) - create individual test branches with individual expected scores for all combinations of `half_precision_ops` and `half_precision_model` - decrease absolute tolerance to `1e-6` - use `10e-2` as absolute tolerance when `half_precision_model` (reasoning: using half_precision_model on cpu results in using dtype=torch.bfloat16 which has only 8 significant precision bits, so we use 10e-2 as absolute tolerance) - enable `torch.use_deterministic_algorithms` to make sure results are as reproducible as possible.
2 parents 64335d9 + a0bf4db commit 6d5a064

File tree

2 files changed

+57
-28
lines changed

2 files changed

+57
-28
lines changed

src/pytorch_ie/pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,17 @@ def __call__(
481481

482482
show_progress_bar = forward_params.pop("show_progress_bar", False)
483483
half_precision_ops = forward_params.pop("half_precision_ops", False)
484+
485+
# Torch documentation recommends: "When entering an autocast-enabled region, Tensors may be any type.
486+
# You should not call half() or bfloat16() on your model(s) or inputs when using autocasting."
487+
# (see https://docs.pytorch.org/docs/stable/amp.html#torch.autocast). So show a warning in this case.
488+
if half_precision_ops:
489+
if self.model.dtype == get_autocast_dtype(self.device.type):
490+
logger.warning(
491+
"Using half precision operations with a model already in half precision. "
492+
"This is not recommended, as it may lead to unexpected results."
493+
)
494+
484495
model_outputs: List = []
485496
with torch.no_grad():
486497
with torch.autocast(device_type=self.device.type, enabled=half_precision_ops):
Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
2-
from typing import Sequence
32

43
import pytest
4+
import torch
55
from pie_core import AnnotationLayer, annotation_field
66

77
from pytorch_ie import AutoPipeline
@@ -11,6 +11,8 @@
1111
from pytorch_ie.pipeline import Pipeline
1212
from pytorch_ie.taskmodules import TransformerRETextClassificationTaskModule
1313

14+
torch.use_deterministic_algorithms(True)
15+
1416

1517
@dataclass
1618
class ExampleDocument(TextDocument):
@@ -22,7 +24,9 @@ class ExampleDocument(TextDocument):
2224
@pytest.mark.parametrize("use_auto", [False, True])
2325
@pytest.mark.parametrize("half_precision_model", [False, True])
2426
@pytest.mark.parametrize("half_precision_ops", [False, True])
25-
def test_re_text_classification(use_auto, half_precision_model, half_precision_ops):
27+
def test_re_text_classification(use_auto, half_precision_model, half_precision_ops, caplog):
28+
29+
# set up the pipeline
2630
model_name_or_path = "pie/example-re-textclf-tacred"
2731
if use_auto:
2832
pipeline = AutoPipeline.from_pretrained(
@@ -45,39 +49,53 @@ def test_re_text_classification(use_auto, half_precision_model, half_precision_o
4549
assert pipeline.taskmodule.is_from_pretrained
4650
assert pipeline.model.is_from_pretrained
4751

52+
# create a document with entities
4853
document = ExampleDocument(
4954
"“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner "
5055
"at SOSV and managing director of IndieBio."
5156
)
57+
document.entities.append(LabeledSpan(start=65, end=75, label="PER"))
58+
document.entities.append(LabeledSpan(start=96, end=100, label="ORG"))
59+
document.entities.append(LabeledSpan(start=126, end=134, label="ORG"))
5260

53-
for start, end, label in [(65, 75, "PER"), (96, 100, "ORG"), (126, 134, "ORG")]:
54-
document.entities.append(LabeledSpan(start=start, end=end, label=label))
61+
# predict relations
62+
with caplog.at_level("WARNING"):
63+
pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
5564

56-
pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
57-
relations: Sequence[BinaryRelation] = document["relations"].predictions
58-
assert len(relations) == 3
65+
# sort to get deterministic order
66+
sorted_relations = sorted(document.relations.predictions)
5967

60-
rels = sorted(relations, key=lambda rel: (rel.head.start + rel.tail.start) / 2)
61-
62-
# Note: The scores are quite low, because the model is trained with the old version for the taskmodule,
63-
# so the argument markers are not correct.
64-
assert (str(rels[0].head), rels[0].label, str(rels[0].tail)) == (
65-
"SOSV",
66-
"org:top_members/employees",
67-
"Po Bronson",
68-
)
69-
assert rels[0].score == pytest.approx(0.398, abs=1e-2)
68+
# check the relations and their scores
69+
assert [ann.resolve() for ann in sorted_relations] == [
70+
("per:employee_of", (("PER", "Po Bronson"), ("ORG", "IndieBio"))),
71+
("org:top_members/employees", (("ORG", "SOSV"), ("PER", "Po Bronson"))),
72+
("org:top_members/employees", (("ORG", "IndieBio"), ("PER", "Po Bronson"))),
73+
]
7074

71-
assert (str(rels[1].head), rels[1].label, str(rels[1].tail)) == (
72-
"Po Bronson",
73-
"per:employee_of",
74-
"IndieBio",
75+
half_precision_warning = (
76+
"Using half precision operations with a model already in half precision. "
77+
"This is not recommended, as it may lead to unexpected results."
7578
)
76-
assert rels[1].score == pytest.approx(0.534, abs=1e-2)
7779

78-
assert (str(rels[2].head), rels[2].label, str(rels[2].tail)) == (
79-
"IndieBio",
80-
"org:top_members/employees",
81-
"Po Bronson",
82-
)
83-
assert rels[2].score == pytest.approx(0.552, abs=1e-2)
80+
scores = [rel.score for rel in sorted_relations]
81+
# General note: The scores are quite low, because the model is trained with the old version
82+
# for the taskmodule, so the argument markers are not correct.
83+
# Below scores were obtained with dependencies from poetry.lock on local machine.
84+
if not half_precision_model and not half_precision_ops:
85+
# we use low tolerance if no half precision is used
86+
# (i.e., no autocast on forward pass and model is not cast to half precision)
87+
assert scores == pytest.approx(
88+
[0.5339038372039795, 0.3984701931476593, 0.5520647764205933], abs=1e-6
89+
)
90+
assert half_precision_warning not in caplog.messages
91+
elif not half_precision_model and half_precision_ops:
92+
# set high tolerance for half precision ops (i.e., autocast on forward pass)
93+
assert scores == pytest.approx([0.53125, 0.39453125, 0.5546875], abs=1e-2)
94+
assert half_precision_warning not in caplog.messages
95+
elif half_precision_model and not half_precision_ops:
96+
# set high tolerance for half precision model (i.e., model cast to half precision)
97+
assert scores == pytest.approx([0.53515625, 0.400390625, 0.55859375], abs=1e-2)
98+
assert half_precision_warning not in caplog.messages
99+
else:
100+
# just check that we got the warning about half precision ops in combination with half precision model
101+
assert half_precision_warning in caplog.messages

0 commit comments

Comments
 (0)