From 346c1552600fe0ab96c112e1979cf83b605e0a1f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 24 Sep 2024 11:06:11 +0100 Subject: [PATCH] Cleaner test eval --- tests/brevitas/graph/test_calibration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 1da67ff3a..f48043558 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -102,9 +102,8 @@ def forward(self, x): computed_scale = model.act.act_quant.scale(), model.act_1.act_quant.scale() reference_values = REFERENCE_SCALES[reference] - assert all([ - torch.allclose(comp, torch.tensor(ref)) for comp, - ref in zip(computed_scale, reference_values)]) + assert torch.allclose(computed_scale[0], torch.tensor(reference_values[0])) + assert torch.allclose(computed_scale[1], torch.tensor(reference_values[1])) def test_calibration_training_state():