Skip to content

Commit

Permalink
minifloat fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 20, 2024
1 parent f7aae4d commit 0bcd6ca
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions src/brevitas/proxy/float_parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,49 +24,73 @@ def bit_width(self):
def scale(self):
if not self.is_quant_enabled:
return None
scale = self.__call__(self.tracked_parameter_list[0]).scale
elif self._cached_weight:
scale = self._cached_weight.scale
else:
scale = self.__call__(self.tracked_parameter_list[0]).scale
return scale

def zero_point(self):
if not self.is_quant_enabled:
return None
zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
elif self._cached_weight:
zero_point = self._cached_weight.zero_point
else:
zero_point = self.__call__(self.tracked_parameter_list[0]).zero_point
return zero_point

def exponent_bit_width(self):
if not self.is_quant_enabled:
return None
exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width
elif self._cached_weight:
exponent_bit_width = self._cached_weight.exponent_bit_width
else:
exponent_bit_width = self.__call__(self.tracked_parameter_list[0]).exponent_bit_width
return exponent_bit_width

def mantissa_bit_width(self):
if not self.is_quant_enabled:
return None
mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width
elif self._cached_weight:
mantissa_bit_width = self._cached_weight.mantissa_bit_width
else:
mantissa_bit_width = self.__call__(self.tracked_parameter_list[0]).mantissa_bit_width
return mantissa_bit_width

def exponent_bias(self):
if not self.is_quant_enabled:
return None
exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias
elif self._cached_weight:
exponent_bias = self._cached_weight.exponent_bias
else:
exponent_bias = self.__call__(self.tracked_parameter_list[0]).exponent_bias
return exponent_bias

def is_saturating(self):
if not self.is_quant_enabled:
return None
saturating = self.__call__(self.tracked_parameter_list[0]).saturating
elif self._cached_weight:
saturating = self._cached_weight.saturating
else:
saturating = self.__call__(self.tracked_parameter_list[0]).saturating
return saturating

def inf_values(self):
if not self.is_quant_enabled:
return None
inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values
elif self._cached_weight:
inf_values = self._cached_weight.inf_values
else:
inf_values = self.__call__(self.tracked_parameter_list[0]).inf_values
return inf_values

def nan_values(self):
if not self.is_quant_enabled:
return None
nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values
elif self._cached_weight:
nan_values = self._cached_weight.nan_values
else:
nan_values = self.__call__(self.tracked_parameter_list[0]).nan_values
return nan_values

@property
Expand Down

0 comments on commit 0bcd6ca

Please sign in to comment.