Commit 3325792

missing fix
Giuseppe5 committed Sep 24, 2024
1 parent 347a381 commit 3325792
Showing 1 changed file with 4 additions and 1 deletion.
src/brevitas/core/scaling/float_scaling.py (5 changes: 4 additions & 1 deletion)
@@ -9,6 +9,7 @@
 import brevitas
 from brevitas.core.utils import StatelessBuffer
 from brevitas.function.ops import max_float
+from brevitas.utils.quant_utils import MAX_MANTISSA_DICT


 class FloatScaling(brevitas.jit.ScriptModule):
@@ -25,6 +26,7 @@ def __init__(
         self.inf_values = inf_values
         self.nan_values = nan_values
         self.saturating = saturating
+        self.max_mantissa_dict = MAX_MANTISSA_DICT

         if max_available_float:
             max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
@@ -36,7 +38,8 @@ def __init__(
     def forward(
             self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor,
             exponent_bias: Tensor) -> Tensor:
-        max_value = max_float(exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias)
+        max_value = max_float(
+            exponent_bit_width, self.max_mantissa_dict[mantissa_bit_width.item()], exponent_bias)
         max_value = max_value if self.max_available_float is None else torch.min(
             max_value, self.max_available_float())
         return max_value
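
What the lookup is for: rather than recomputing the largest representable mantissa value from mantissa_bit_width on every forward pass, the commit wires in a precomputed table, MAX_MANTISSA_DICT, keyed by the integer bit width (hence mantissa_bit_width.item()). Below is a minimal sketch of the idea; only the names MAX_MANTISSA_DICT and max_float come from the diff, and their bodies are assumptions based on the standard floating-point formula, not Brevitas's actual implementation.

    import torch

    # Assumption: with an implicit leading 1 and m explicit mantissa bits, the
    # largest significand is 1 + sum_{i=1..m} 2^-i = 2 - 2^-m. Precomputing it
    # once per bit width turns the per-call computation into a dict lookup.
    MAX_MANTISSA_DICT = {m: 2.0 - 2.0 ** -m for m in range(0, 16)}

    def max_float(exponent_bit_width: torch.Tensor,
                  max_mantissa: float,
                  exponent_bias: torch.Tensor) -> torch.Tensor:
        # Largest unbiased exponent for the format, then scale the maximum
        # significand by it.
        max_exponent = 2.0 ** exponent_bit_width - 1.0 - exponent_bias
        return max_mantissa * 2.0 ** max_exponent

    # FP8 E4M3 with bias 7: 1.875 * 2**8 = 480 before any clamping.
    e, b = torch.tensor(4.0), torch.tensor(7.0)
    print(max_float(e, MAX_MANTISSA_DICT[3], b))  # tensor(480.)

The torch.min against max_available_float in forward is what clamps this theoretical maximum down to a format's declared limit, e.g. 448 for OCP FP8 E4M3, where the all-ones exponent and mantissa pattern is reserved for NaN.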
