diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 95d6cc5d9..b22994275 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -28,7 +28,9 @@ REFERENCE_SCALES = { 'int_quant': (0.00935234408825635910, 0.01362917013466358185), 'fp_quant': (0.00249395845457911491, 0.00363444536924362183)} -REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) +REFERENCE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]]) +REFERENCE_WEIGHTS = torch.tensor([[1.0023, 0.0205, 1.4604], [-0.2918, -1.8218, -0.7010], + [1.4573, -0.9074, -0.2708]]) def compute_quantile(x, q): @@ -86,9 +88,7 @@ class TestModel(nn.Module): def __init__(self): super(TestModel, self).__init__() self.act = qnn.QuantReLU(act_quant=act_quant) - self.linear_weights = torch.tensor([[1.0023, 0.0205, - 1.4604], [-0.2918, -1.8218, -0.7010], - [1.4573, -0.9074, -0.2708]]) + self.linear_weights = REFERENCE_WEIGHTS self.act_1 = qnn.QuantIdentity(act_quant=act_quant) def forward(self, x): @@ -97,7 +97,7 @@ def forward(self, x): return self.act_1(o) # Reference input - inp = REFERNECE_INP + inp = REFERENCE_INP model = TestModel() model.eval() with torch.no_grad():