diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 3f80f1dd4..796940f4f 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -82,7 +82,7 @@ def __init__( @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor: - stats = self.parameter_list_stats() + stats = self.parameter_list_stats(x) return self.scale_shift_zero_point(-stats, scale, bit_width) @@ -266,7 +266,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor: value = self.scale_shift_zero_point(value, scale, bit_width) return value else: - stats = self.parameter_list_stats() + stats = self.parameter_list_stats(x) # workaround to avoid find_ununsed_parameter=True in DDP stats = stats + 0. * self.value if self.local_loss_mode: