

@RainbowRivey
Collaborator

@RainbowRivey RainbowRivey commented Jul 23, 2025

This PR adjusts test_annotation_pipeline so that the current behavior becomes transparent: which parameter settings, dependency versions, and devices produce which result scores? Later (probably breaking) PRs will address anything that needs to be fixed.

In detail, this PR does:

  • adjust test_annotation_pipeline:
    • streamline the test (e.g. use resolve() etc.)
    • create individual test branches with individual expected scores for all combinations of half_precision_ops and half_precision_model
    • decrease absolute tolerance to 1e-6
    • use 10e-2 as absolute tolerance when half_precision_model is set (reasoning: using half_precision_model on CPU results in dtype=torch.bfloat16, which has only 8 bits of significand precision)
    • enable torch.use_deterministic_algorithms to make sure results are as reproducible as possible.
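
The tolerance rule from the list above could be sketched as follows. This is a minimal illustration, not the test code from this PR: `pick_abs_tolerance` and `scores_match` are hypothetical helper names, and `math.isclose` with `rel_tol=0.0` stands in for `pytest.approx(..., abs=...)` so the snippet runs without pytest. The score values are taken from the CPU failure in the test output quoted further down in this conversation.

```python
import math

# Tolerances as described in this PR: tight for full precision,
# loose whenever the model itself runs in half precision.
FULL_PRECISION_ABS_TOL = 1e-6
HALF_PRECISION_MODEL_ABS_TOL = 10e-2


def pick_abs_tolerance(half_precision_model: bool) -> float:
    """Return the absolute tolerance for a parametrization (hypothetical helper)."""
    if half_precision_model:
        return HALF_PRECISION_MODEL_ABS_TOL
    return FULL_PRECISION_ABS_TOL


def scores_match(obtained, expected, abs_tol: float) -> bool:
    """Element-wise comparison using an absolute tolerance only."""
    return len(obtained) == len(expected) and all(
        math.isclose(o, e, rel_tol=0.0, abs_tol=abs_tol)
        for o, e in zip(obtained, expected)
    )


expected = [0.53125, 0.39453125, 0.5546875]
obtained = [0.53125, 0.396484375, 0.5546875]  # index 1 differs by 0.001953125

strict_ok = scores_match(obtained, expected, pick_abs_tolerance(False))
relaxed_ok = scores_match(obtained, expected, pick_abs_tolerance(True))
print(strict_ok, relaxed_ok)
```

With these numbers the strict comparison rejects the scores and the relaxed one accepts them, which is the behaviour the two tolerance levels are meant to separate.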

In addition, this PR changes the following:

  • show a warning if half_precision_ops is used in combination with half_precision_model, because the PyTorch documentation recommends against it; the respective test case now checks only for that warning (no longer the scores!)
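
A sketch of how such a warning could be emitted and checked. The function names, the warning category, and the message text are assumptions made for illustration; the actual pipeline code may well use its logger instead of the `warnings` module.

```python
import warnings


def check_half_precision_args(half_precision_model: bool, half_precision_ops: bool) -> None:
    """Warn when both half-precision switches are enabled (hypothetical helper)."""
    if half_precision_model and half_precision_ops:
        warnings.warn(
            "half_precision_ops (autocast) is combined with half_precision_model; "
            "using both together is not recommended.",
            UserWarning,
            stacklevel=2,
        )


def collect_warnings(half_precision_model: bool, half_precision_ops: bool) -> list:
    """Run the check and return the messages of all warnings it raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure repeated warnings are recorded
        check_half_precision_args(half_precision_model, half_precision_ops)
    return [str(w.message) for w in caught]


both = collect_warnings(True, True)
only_model = collect_warnings(True, False)
print(len(both), len(only_model))
```

In a pytest test the same check would typically be written with `pytest.warns(UserWarning)` around the pipeline call.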

Important

Although the current half-precision model/ops tests now pass with the relaxed tolerance, the difference may still exceed this limit after upgrading poetry.lock.
We may need to adjust scores when updating!

The reason is that the bfloat16 type (the default for half-precision model/ops on CPU) is quite unstable:
It has only 8 precision bits. Results of calculations with it may depend heavily on the order of operations, which leads to different results across versions of Torch and/or its backends, and even across machines.
Also, these tests may still fail locally on some machines because of this.
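
The 8-bit claim can be checked without torch. The sketch below emulates bfloat16 rounding on top of float32 using only the standard library; `bfloat16_spacing` and `to_bfloat16` are illustrative helpers, not part of the project. It shows that the 0.001953125 difference reported by the failing CPU test further down is exactly one bfloat16 step at that magnitude, so a single differently rounded operation is enough to produce it.

```python
import math
import struct


def bfloat16_spacing(x: float) -> float:
    """Gap between neighbouring bfloat16 values around x (normal, non-zero x)."""
    # frexp gives x = m * 2**e with 0.5 <= |m| < 1, so the binary exponent is e - 1.
    _, e = math.frexp(x)
    # 8 significand bits = 1 implicit + 7 stored, hence the step is 2**(exponent - 7).
    return 2.0 ** (e - 1 - 7)


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the dropped 16 bits
    bits &= 0xFFFF0000  # keep sign, exponent and the 7 stored fraction bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


expected = 0.39453125   # expected score in the failing CPU test
obtained = 0.396484375  # obtained score in the same test

step = bfloat16_spacing(expected)
print(step, obtained - expected == step)
print(to_bfloat16(expected) == expected, to_bfloat16(obtained) == obtained)
```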

Requires:

TODO:

  • get tests passing (without lowering the tolerance), locally and at CI
    • use low tolerance (abs=1e-6) for full precision
    • use high tolerance (abs=1e-2) for half precision (both cases: half_precision_model and half_precision_ops)
    • show warning if both are enabled
  • change CI workflow back to only run slow tests
  • adjust this PR description to what it actually implements

Related:

Follow-up:

@RainbowRivey RainbowRivey marked this pull request as draft July 23, 2025 18:20
@RainbowRivey RainbowRivey added the bug Something isn't working label Jul 23, 2025
@RainbowRivey
Collaborator Author

RainbowRivey commented Jul 23, 2025

Local tests have shown that:

  • The problem occurs with Torch versions starting from 2.6
  • It only happens with half-precision-model=True, which casts the model to bfloat16 when running on CPU
  • There is no error with half-precision-model=True on GPU; there the model is converted to float16.
  • I get different numbers in the failing tests when running locally. (Tests on CI fail with prediction score 0.408, locally I get 0.412)

Also, not relevant to the current bug:

  • the minimal version of torch (>=1.10) is not compatible with the restriction pytorch-lightning = "^2";
    use torch>=1.11 instead.

@RainbowRivey
Collaborator Author

RainbowRivey commented Jul 23, 2025

I could not find any issue or release note for Torch 2.6 directly related to what we experience here.
Possible solutions for now may be:

  • avoid using bfloat16 by changing our get_autocast_dtype() function in pipeline.py to return float16 also for CPU.
  • adjust the test to the new bfloat16 results, which looks like a bad idea, since they differ between machines and between torch versions.
  • use model.compile(), see UPD below

This may also be a good step:

  • add tests for the model behind this pipeline test. Once we know more precisely what went wrong, we could create an issue at Torch if the problem is really on their side. This won't make current versions work correctly, but at least it may get fixed in future versions.

UPD: Calling model.compile() before using the model seems to fix the problem somehow. But torch.compile() is only available since torch 2.0. We could call it only if it is available, e.g. if hasattr(torch, "compile"): model.compile(), so we don't need to change the minimal Torch version.
But compiling each time makes the test even slower. Perhaps we could move pipeline creation into a parametrized fixture, so it is called half as often, since it does not depend on the 'half_precision_ops' parametrization.
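
A minimal sketch of the "compile only if available" idea. Two things here are assumptions rather than what the comment above proposes: the helper name `compile_if_available`, and the use of `torch.compile(model)` (which returns a wrapped module) instead of the in-place `model.compile()`. The torch module is passed in as an argument so that the snippet can be exercised with stand-in objects when torch is not installed; in real code one would pass the imported `torch` package.

```python
from types import SimpleNamespace


def compile_if_available(model, torch_module):
    """Return a compiled model if `torch_module.compile` exists, else the model unchanged."""
    compile_fn = getattr(torch_module, "compile", None)
    if callable(compile_fn):
        return compile_fn(model)
    return model


# Stand-ins for illustration only: an "old" torch without `compile`
# and a "new" torch whose `compile` just tags the model.
old_torch = SimpleNamespace()
new_torch = SimpleNamespace(compile=lambda m: ("compiled", m))

model = object()
unchanged = compile_if_available(model, old_torch)
wrapped = compile_if_available(model, new_torch)
print(unchanged is model, wrapped[0])
```

Keeping the check in one helper means the minimal torch version does not have to change, and the slow compilation step could later be moved into a fixture without touching the test body.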

Unrelated:
We could also do the same with get_autocast_dtype(), e.g. if hasattr(torch, "get_autocast_dtype"), since it has a # TODO: use torch.get_autocast_dtype when available (pipeline.py#28)
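
The same pattern, sketched for the dtype lookup. This is not the existing function from pipeline.py, whose body is not shown in this thread: `resolve_autocast_dtype` is a hypothetical name, and the fallback simply mirrors the defaults discussed above (bfloat16 on CPU, float16 otherwise). As before, the torch module is a parameter so the example runs with stand-ins.

```python
from types import SimpleNamespace


def resolve_autocast_dtype(torch_module, device_type: str):
    """Prefer torch.get_autocast_dtype when present, else fall back to fixed defaults."""
    if hasattr(torch_module, "get_autocast_dtype"):
        return torch_module.get_autocast_dtype(device_type)
    # Fallback (assumption): the defaults mentioned in this conversation.
    if device_type == "cpu":
        return torch_module.bfloat16
    return torch_module.float16


# Stand-ins for illustration only; the strings take the place of torch dtypes.
old_torch = SimpleNamespace(bfloat16="bfloat16", float16="float16")
new_torch = SimpleNamespace(
    bfloat16="bfloat16",
    float16="float16",
    get_autocast_dtype=lambda device_type: "from-torch:" + device_type,
)

print(resolve_autocast_dtype(old_torch, "cpu"), resolve_autocast_dtype(old_torch, "cuda"))
print(resolve_autocast_dtype(new_torch, "cpu"))
```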

@ArneBinder ArneBinder force-pushed the fix/half_precision_model_tests branch from d2f3dc8 to 6725e89 Compare July 24, 2025 13:52
@ArneBinder ArneBinder changed the title Fix test_annotation_pipeline fails with half-precision-model=True Fix test_annotation_pipeline fails with half-precision-model=True Jul 24, 2025
@RainbowRivey
Collaborator Author

use_deterministic_algorithms() removed differences in results for half_precision_ops, both between versions and different machines (local and CI tests)
However, for half_precision_model this is still not the case.

@ArneBinder
Owner

ArneBinder commented Jul 25, 2025

use_deterministic_algorithms() removed differences in results for half_precision_ops, both between versions and different machines (local and CI tests)

Nice! But is there a min torch version that this requires?

@RainbowRivey
Collaborator Author

use_deterministic_algorithms() removed differences in results for half_precision_ops, both between versions and different machines (local and CI tests)

Nice! But is there a min torch version that this requires?

Our minimal torch 1.10 already supports it, so it should be OK.

Anyways, bfloat16 has only 8 significant precision bits, i think we should not try to get it 10e-6 exact, but around 10e-3

@ArneBinder
Owner

Anyways, bfloat16 has only 8 significant precision bits, i think we should not try to get it 10e-6 exact, but around 10e-3

what is the reason for using bfloat16 in the first place? I'm asking because above you proposed to use float16 everywhere...

@ArneBinder
Owner

ArneBinder commented Jul 25, 2025

hmm strange. With my recent change, only half_precision_ops==True fails for me locally. But now both cases, half_precision_ops==True or the half_precision_model==True fail at CI, although I did not adjust the torch >=2.6 branch at all... I'm a bit clueless, any idea what's happening? @RainbowRivey

EDIT: I was totally assuming, the tests on the CI are running with torch=2.7... is this not the case?
UPD: It uses torch 2.3... see latest test run
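
Since the test has a separate branch for torch >= 2.6 and the CI turned out to run 2.3, an explicit version check can make it obvious which branch is taken. The helpers below are a standard-library sketch with hypothetical names; in practice `packaging.version.parse(torch.__version__)` is the more common choice.

```python
import re


def parse_version(version: str) -> tuple:
    """Extract the leading numeric release parts, e.g. '2.3.0+cu121' -> (2, 3, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    # A missing patch component is treated as 0.
    return tuple(int(part) for part in match.groups(default="0"))


def is_at_least(version: str, minimum: str) -> bool:
    """True if `version` is greater than or equal to `minimum`."""
    return parse_version(version) >= parse_version(minimum)


ci_uses_new_branch = is_at_least("2.3.0+cu121", "2.6")
updated_uses_new_branch = is_at_least("2.7.1", "2.6")
print(ci_uses_new_branch, updated_uses_new_branch)
```

Printing the detected version at the start of the test session would also have made the 2.3 vs. 2.7 mismatch visible right away.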

@ArneBinder
Owner

btw do you have a GPU on your local machine? (I don't)

@RainbowRivey
Collaborator Author

Anyways, bfloat16 has only 8 significant precision bits, i think we should not try to get it 10e-6 exact, but around 10e-3

what is the reason for using bfloat16 in the first place? I'm asking because above you proposed to use float16 everywhere...

It is the default in PyTorch autocast for CPU, so I think we just took it.

btw do you have a GPU on your local machine? (I don't)

Yes, I do! But I double-checked that the tests run on CPU. I also tried the tests on GPU; the results are also different from CPU, but I think they were stable across versions. I could rerun them to make sure.

@ArneBinder ArneBinder mentioned this pull request Jul 25, 2025
@ArneBinder
Owner

@RainbowRivey Can you check one more time which tests now pass locally on your machine (with and w/o GPU)? For me, everything is green now (locally), for both versions (from poetry.lock and with poetry update --with dev, i.e., updated dependencies).

@RainbowRivey
Collaborator Author

Here are the results of the local tests:

  • the CUDA implementation of nn.Linear in previous versions had no deterministic variant (it raises an error that names the fix); this is solved in the latest versions
  • all CUDA results differ slightly from CPU, including full precision
  • two CPU tests still fail in both versions (half_precision_ops=True)
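
For the first point, the error message quoted in the log below asks for the environment variable CUBLAS_WORKSPACE_CONFIG to be set before the application runs. A hedged sketch of doing that from Python follows; the helper name is hypothetical, and placing the call at the top of tests/conftest.py (before any CUDA work) is an assumption. Exporting the variable in the shell or CI configuration is the more robust option.

```python
import os


def ensure_cublas_workspace_config(env=os.environ, value: str = ":4096:8") -> str:
    """Set CUBLAS_WORKSPACE_CONFIG unless a value was already chosen; return the active value."""
    return env.setdefault("CUBLAS_WORKSPACE_CONFIG", value)


# Demonstrated on plain dicts so the real environment is left untouched.
empty_env = {}
preset_env = {"CUBLAS_WORKSPACE_CONFIG": ":16:8"}

print(ensure_cublas_workspace_config(empty_env))   # falls back to the default value
print(ensure_cublas_workspace_config(preset_env))  # keeps the existing value
```

Both values, `:4096:8` and `:16:8`, are the ones named in the error message itself.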

torch=2.3.0 (from poetry.lock)

Details
============================================================================================ FAILURES =============================================================================================
______________________________________________________________________ test_re_text_classification[False-False-False-cuda:0] ______________________________________________________________________
tests/pipeline/test_re_text_classification.py:92: in test_re_text_classification
    assert scores == pytest.approx(
E   assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
E     
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 1.4901161193847656e-06
E     Max relative difference: 2.790974435232368e-06
E     Index | Obtained           | Expected                    
E     0     | 0.5339053273200989 | 0.5339038372039795 ± 1.0e-06
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
______________________________________________________________________ test_re_text_classification[False-False-True-cuda:0] _______________________________________________________________________
tests/pipeline/test_re_text_classification.py:92: in test_re_text_classification
    assert scores == pytest.approx(
E   assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
E     
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 1.4901161193847656e-06
E     Max relative difference: 2.790974435232368e-06
E     Index | Obtained           | Expected                    
E     0     | 0.5339053273200989 | 0.5339038372039795 ± 1.0e-06
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING  pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}
______________________________________________________________________ test_re_text_classification[False-True-False-cuda:0] _______________________________________________________________________
tests/pipeline/test_re_text_classification.py:67: in test_re_text_classification
    pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
src/pytorch_ie/pipeline.py:501: in __call__
    output = self.forward(batch, **forward_params)
src/pytorch_ie/pipeline.py:371: in forward
    model_outputs = self._forward(model_inputs, **forward_params)
src/pytorch_ie/pipeline.py:339: in _forward
    return self.model.predict(inputs, **forward_parameters)
src/pytorch_ie/model.py:38: in predict
    outputs = self(inputs, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
src/pytorch_ie/models/transformer_text_classification.py:95: in forward
    output = self.model(**inputs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1150: in forward
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:742: in forward
    pooled_output = self.dense(first_token_tensor)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:116: in forward
    return F.linear(input, self.weight, self.bias)
E   RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
_______________________________________________________________________ test_re_text_classification[False-True-True-cuda:0] _______________________________________________________________________
tests/pipeline/test_re_text_classification.py:67: in test_re_text_classification
    pipeline(document, batch_size=2, half_precision_ops=half_precision_ops)
src/pytorch_ie/pipeline.py:501: in __call__
    output = self.forward(batch, **forward_params)
src/pytorch_ie/pipeline.py:371: in forward
    model_outputs = self._forward(model_inputs, **forward_params)
src/pytorch_ie/pipeline.py:339: in _forward
    return self.model.predict(inputs, **forward_parameters)
src/pytorch_ie/model.py:38: in predict
    outputs = self(inputs, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
src/pytorch_ie/models/transformer_text_classification.py:95: in forward
    output = self.model(**inputs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:1150: in forward
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py:742: in forward
    pooled_output = self.dense(first_token_tensor)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
../../.cache/pypoetry/virtualenvs/pytorch-ie-uKFcdlbK-py3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:116: in forward
    return F.linear(input, self.weight, self.bias)
E   RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING  pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}
________________________________________________________________________ test_re_text_classification[True-False-False-cpu] ________________________________________________________________________
tests/pipeline/test_re_text_classification.py:98: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.39453125, 0.5546875], abs=1e-6)
E   assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...75 ± 1.0e-06])
E     
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 0.001953125
E     Max relative difference: 0.0049261083743842365
E     Index | Obtained    | Expected            
E     1     | 0.396484375 | 0.39453125 ± 1.0e-06
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
______________________________________________________________________ test_re_text_classification[True-False-False-cuda:0] _______________________________________________________________________
tests/pipeline/test_re_text_classification.py:98: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.39453125, 0.5546875], abs=1e-6)
E   assert [0.5341796875... 0.5517578125] == approx([0.531...75 ± 1.0e-06])
E     
E     comparison failed. Mismatched elements: 3 / 3:
E     Max absolute difference: 0.00439453125
E     Max relative difference: 0.011015911872705019
E     Index | Obtained      | Expected            
E     0     | 0.5341796875  | 0.53125 ± 1.0e-06   
E     1     | 0.39892578125 | 0.39453125 ± 1.0e-06
E     2     | 0.5517578125  | 0.5546875 ± 1.0e-06
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
________________________________________________________________________ test_re_text_classification[True-False-True-cpu] _________________________________________________________________________
tests/pipeline/test_re_text_classification.py:98: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.39453125, 0.5546875], abs=1e-6)
E   assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...75 ± 1.0e-06])
E     
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 0.001953125
E     Max relative difference: 0.0049261083743842365
E     Index | Obtained    | Expected            
E     1     | 0.396484375 | 0.39453125 ± 1.0e-06
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING  pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}
_______________________________________________________________________ test_re_text_classification[True-False-True-cuda:0] _______________________________________________________________________
tests/pipeline/test_re_text_classification.py:98: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.39453125, 0.5546875], abs=1e-6)
E   assert [0.5341796875... 0.5517578125] == approx([0.531...75 ± 1.0e-06])
E     
E     comparison failed. Mismatched elements: 3 / 3:
E     Max absolute difference: 0.00439453125
E     Max relative difference: 0.011015911872705019
E     Index | Obtained      | Expected            
E     0     | 0.5341796875  | 0.53125 ± 1.0e-06   
E     1     | 0.39892578125 | 0.39453125 ± 1.0e-06
E     2     | 0.5517578125  | 0.5546875 ± 1.0e-06
-------------------------------------------------------------------------------------- Captured stderr call ---------------------------------------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
---------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------
WARNING  pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING  pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING  pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}

or, in short:

======================================================== short test summary info =========================================================
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-False-False-cuda:0] - assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-False-True-cuda:0] - assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-True-False-cuda:0] - RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not determin...
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-True-True-cuda:0] - RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not determin...
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-False-cpu] - assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...75 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-False-cuda:0] - assert [0.5341796875... 0.5517578125] == approx([0.531...75 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-True-cpu] - assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...75 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-True-cuda:0] - assert [0.5341796875... 0.5517578125] == approx([0.531...75 ± 1.0e-06])
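
The scores in the `True-False-True-cuda:0` failure hint at which half-precision format was in play. As a rough check (a sketch, not part of the test suite; `fits_precision` is a hypothetical helper), one can ask whether a value is exactly representable with a given number of significand bits: 8 for bfloat16, 11 for float16. The expected values in the assertion fit in 8 bits, while the values obtained on `cuda:0` need 11. That would be consistent with float16 being used on CUDA, but it is an inference from the numbers, not something the log states.

```python
import math

BFLOAT16_BITS = 8   # significand precision, including the implicit leading bit
FLOAT16_BITS = 11


def fits_precision(x: float, bits: int) -> bool:
    """True if x is exactly representable with `bits` significand bits.

    The exponent range is ignored, which is fine for scores in (0, 1).
    """
    mantissa, _ = math.frexp(x)  # x == mantissa * 2**exp, 0.5 <= |mantissa| < 1
    return (mantissa * 2**bits).is_integer()


# Values copied from the failure report above.
expected = [0.53125, 0.39453125, 0.5546875]
obtained_cuda = [0.5341796875, 0.39892578125, 0.5517578125]

expected_fits_bf16 = all(fits_precision(v, BFLOAT16_BITS) for v in expected)
cuda_fits_bf16 = any(fits_precision(v, BFLOAT16_BITS) for v in obtained_cuda)
cuda_fits_fp16 = all(fits_precision(v, FLOAT16_BITS) for v in obtained_cuda)

print(expected_fits_bf16, cuda_fits_bf16, cuda_fits_fp16)
```

If the two devices really use different 16-bit formats, a single set of expected scores with `abs=1e-6` cannot match both, which may be why these cases need per-device expectations or a looser tolerance.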

And with torch==2.7.1 (all dependencies at the latest versions permitted by pyproject.toml):

================================================================ FAILURES ================================================================
_________________________________________ test_re_text_classification[False-False-False-cuda:0] __________________________________________
tests/pipeline/test_re_text_classification.py:92: in test_re_text_classification
    assert scores == pytest.approx(
E   assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
E
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 1.4901161193847656e-06
E     Max relative difference: 2.790974435232368e-06
E     Index | Obtained           | Expected
E     0     | 0.5339053273200989 | 0.5339038372039795 ± 1.0e-06
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
__________________________________________ test_re_text_classification[False-False-True-cuda:0] __________________________________________
tests/pipeline/test_re_text_classification.py:92: in test_re_text_classification
    assert scores == pytest.approx(
E   assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
E
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 1.4901161193847656e-06
E     Max relative difference: 2.790974435232368e-06
E     Index | Obtained           | Expected
E     0     | 0.5339053273200989 | 0.5339038372039795 ± 1.0e-06
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}
__________________________________________ test_re_text_classification[False-True-False-cuda:0] __________________________________________
tests/pipeline/test_re_text_classification.py:109: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.412109375, 0.55859375], abs=1e-2)
E   assert [0.5341796875...0.55322265625] == approx([0.531...59375 ± 0.01])
E
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 0.0126953125
E     Max relative difference: 0.03178484107579462
E     Index | Obtained     | Expected
E     1     | 0.3994140625 | 0.412109375 ± 0.01
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
__________________________________________ test_re_text_classification[False-True-True-cuda:0] ___________________________________________
tests/pipeline/test_re_text_classification.py:109: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.412109375, 0.55859375], abs=1e-2)
E   assert [0.5341796875...0.55322265625] == approx([0.531...59375 ± 0.01])
E
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 0.0126953125
E     Max relative difference: 0.03178484107579462
E     Index | Obtained     | Expected
E     1     | 0.3994140625 | 0.412109375 ± 0.01
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}
___________________________________________ test_re_text_classification[True-False-False-cpu] ____________________________________________
tests/pipeline/test_re_text_classification.py:100: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.396484375, 0.55078125], abs=1e-6)
E   assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...25 ± 1.0e-06])
E
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 0.00390625
E     Max relative difference: 0.007042253521126761
E     Index | Obtained  | Expected
E     2     | 0.5546875 | 0.55078125 ± 1.0e-06
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
__________________________________________ test_re_text_classification[True-False-False-cuda:0] __________________________________________
tests/pipeline/test_re_text_classification.py:100: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.396484375, 0.55078125], abs=1e-6)
E   assert [0.533203125,... 0.5517578125] == approx([0.531...25 ± 1.0e-06])
E
E     comparison failed. Mismatched elements: 3 / 3:
E     Max absolute difference: 0.00244140625
E     Max relative difference: 0.006119951040391677
E     Index | Obtained      | Expected
E     0     | 0.533203125   | 0.53125 ± 1.0e-06
E     1     | 0.39892578125 | 0.396484375 ± 1.0e-06
E     2     | 0.5517578125  | 0.55078125 ± 1.0e-06
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
____________________________________________ test_re_text_classification[True-False-True-cpu] ____________________________________________
tests/pipeline/test_re_text_classification.py:100: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.396484375, 0.55078125], abs=1e-6)
E   assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...25 ± 1.0e-06])
E
E     comparison failed. Mismatched elements: 1 / 3:
E     Max absolute difference: 0.00390625
E     Max relative difference: 0.007042253521126761
E     Index | Obtained  | Expected
E     2     | 0.5546875 | 0.55078125 ± 1.0e-06
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}
__________________________________________ test_re_text_classification[True-False-True-cuda:0] ___________________________________________
tests/pipeline/test_re_text_classification.py:100: in test_re_text_classification
    assert scores == pytest.approx([0.53125, 0.396484375, 0.55078125], abs=1e-6)
E   assert [0.533203125,... 0.5517578125] == approx([0.531...25 ± 1.0e-06])
E
E     comparison failed. Mismatched elements: 3 / 3:
E     Max absolute difference: 0.00244140625
E     Max relative difference: 0.006119951040391677
E     Index | Obtained      | Expected
E     0     | 0.533203125   | 0.53125 ± 1.0e-06
E     1     | 0.39892578125 | 0.396484375 ± 1.0e-06
E     2     | 0.5517578125  | 0.55078125 ± 1.0e-06
---------------------------------------------------------- Captured stderr call ----------------------------------------------------------
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the __call__ method is faster than using a method to encode the text followed by a call to the pad method to get a padded encoding.
----------------------------------------------------------- Captured log call ------------------------------------------------------------
WARNING pytorch_ie.taskmodules.transformer_re_text_classification:transformer_re_text_classification.py:179 The parameter entity_annotation is deprecated and will be discarded because it is not necessary anymore. The target of the relation layer already specifies the entity layer.
WARNING pytorch_ie.models.transformer_text_classification:transformer_text_classification.py:49 t_total is deprecated, we use estimated_stepping_batches from the pytorch lightning trainer instead
WARNING pytorch_ie.pipeline:pipeline.py:97 Ignoring remaining kwargs: {'binary_output': False}

======================================================== short test summary info =========================================================
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-False-False-cuda:0] - assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-False-True-cuda:0] - assert [0.5339053273...0644783973694] == approx([0.533...33 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-True-False-cuda:0] - assert [0.5341796875...0.55322265625] == approx([0.531...59375 ± 0.01])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[False-True-True-cuda:0] - assert [0.5341796875...0.55322265625] == approx([0.531...59375 ± 0.01])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-False-cpu] - assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...25 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-False-cuda:0] - assert [0.533203125,... 0.5517578125] == approx([0.531...25 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-True-cpu] - assert [0.53125, 0.3...75, 0.5546875] == approx([0.531...25 ± 1.0e-06])
FAILED tests/pipeline/test_re_text_classification.py::test_re_text_classification[True-False-True-cuda:0] - assert [0.533203125,... 0.5517578125] == approx([0.531...25 ± 1.0e-06])
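
To put the differences in this run in perspective, it may help to express them in units of the spacing between adjacent representable values (ULPs). The following is a minimal sketch, assuming 24 significand bits for float32 and 8 for bfloat16; `spacing` is a hypothetical helper, and the numbers are copied from the log above.

```python
import math


def spacing(x: float, bits: int) -> float:
    """Gap between adjacent values near x for a format with `bits` significand bits."""
    _, exp = math.frexp(x)  # x == m * 2**exp with 0.5 <= |m| < 1
    return math.ldexp(1.0, exp - bits)


# Full precision on cuda:0 (assumed float32, 24 significand bits).
fp32_expected = 0.5339038372039795
fp32_obtained = 0.5339053273200989
fp32_ulps = abs(fp32_obtained - fp32_expected) / spacing(fp32_expected, 24)

# Half-precision model on cpu (assumed bfloat16, 8 significand bits).
bf16_expected = 0.55078125
bf16_obtained = 0.5546875
bf16_ulps = abs(bf16_obtained - bf16_expected) / spacing(bf16_expected, 8)

# The absolute tolerance of 1e-6, expressed in float32 ULPs at this magnitude.
tol_ulps = 1e-6 / spacing(fp32_expected, 24)

print(f"float32: {fp32_ulps:.1f} ULPs off, tolerance is about {tol_ulps:.1f} ULPs")
print(f"bfloat16: {bf16_ulps:.1f} ULP off")
```

Under these assumptions, the bfloat16 mismatch on CPU is a single step on the grid, which no tolerance much smaller than 2^-8 can absorb, while the float32 mismatch on CUDA is a couple of dozen steps, just beyond what `abs=1e-6` allows at this magnitude.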

@RainbowRivey RainbowRivey mentioned this pull request Jul 28, 2025
13 tasks
@ArneBinder ArneBinder force-pushed the fix/half_precision_model_tests branch from b27e0f9 to b4c5714 Compare August 2, 2025 14:44

codecov bot commented Aug 2, 2025

Codecov Report

❌ Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.10%. Comparing base (64335d9) to head (a0bf4db).
⚠️ Report is 29 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/pytorch_ie/pipeline.py | 0.00% | 3 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #477      +/-   ##
==========================================
- Coverage   72.20%   72.10%   -0.10%     
==========================================
  Files          32       32              
  Lines        2173     2176       +3     
  Branches      316      318       +2     
==========================================
  Hits         1569     1569              
- Misses        523      526       +3     
  Partials       81       81              
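
The percentages in the diff can be reproduced from the raw counts. This is a small sketch, assuming coverage is computed as hits divided by lines; that formula matches the numbers here, but the report itself does not state it, and `coverage_percent` is a hypothetical helper.

```python
def coverage_percent(hits: int, lines: int) -> float:
    """Coverage as a percentage, assuming hits / lines."""
    return 100.0 * hits / lines


# Counts copied from the coverage diff: 3 new lines, none of them hit.
base = coverage_percent(hits=1569, lines=2173)
head = coverage_percent(hits=1569, lines=2176)

print(f"{base:.2f}% -> {head:.2f}% ({head - base:+.2f})")
```

The three uncovered lines added to `src/pytorch_ie/pipeline.py` account for the whole drop, since the number of hits is unchanged.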

@ArneBinder ArneBinder marked this pull request as ready for review August 2, 2025 15:58
@RainbowRivey RainbowRivey merged commit 6d5a064 into main Aug 4, 2025
5 of 6 checks passed
@RainbowRivey RainbowRivey deleted the fix/half_precision_model_tests branch August 4, 2025 15:27
ArneBinder added a commit that referenced this pull request Aug 4, 2025
This is the last non-breaking piece of the `pie-core` refactor. See
[this](ArneBinder/pie-core#17) for context.

requires: 
 - [x] #477
 - [x] [pie-core
0.2.1](https://github.com/ArneBinder/pie-core/releases/tag/v0.2.1)
because of ArneBinder/pie-core#80