Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions src/finn/custom_op/fpgadataflow/elementwise_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,19 +346,35 @@ def minimize_weight_bit_width(self, model: ModelWrapper):
# Remember the "style" of receiving the input for further code
# generation
self.set_nodeattr("lhs_style", "const")
# Minimum and maximum "weight" on the left hand side, determining
# the range of values which needs to be represented
_min = lhs.min()
_max = lhs.max()
# Determine whether signed or unsigned type is required for
# representing the weights and select the largest "signed magnitude"
_mag = _max if _min > 0 else _min if (abs(_min) > _max) else (-_max - 1)
# Smallest data type large enough to represent this range of values
dtype = DataType.get_smallest_possible(_mag)
lhs_dtype = self.get_input_datatype(0)
# ignore minimization for floats
if not lhs_dtype.get_canonical_name().startswith("FLOAT"):
if lhs_dtype.is_integer():
# Minimum and maximum "weight" on the left hand side, determining
# the range of values which needs to be represented
_min = lhs.min()
_max = lhs.max()
# Determine whether signed or unsigned type is required for
# representing the weights and select the largest "signed magnitude"
_mag = _max if _min > 0 else _min if (abs(_min) > _max) else (-_max - 1)
# Smallest data type large enough to represent this range of values
lhs_dtype = DataType.get_smallest_possible(_mag)
elif lhs_dtype.is_fixed_point():
# Convert the fixed-point array to corresponding integers and get
# smallest integer representation
lhs = lhs / lhs_dtype.scale_factor()
_min = lhs.min()
_max = lhs.max()
_mag = _max if _min > 0 else _min if (abs(_min) > _max) else (-_max - 1)
dtype = DataType.get_smallest_possible(_mag)
_total_bits = dtype.bitwidth() if dtype.signed() else dtype.bitwidth() + 1
_integer_bits = _total_bits - lhs_dtype.frac_bits()
lhs_dtype = DataType[f"FIXED<{_total_bits},{_integer_bits}>"]

# Update the corresponding data type attribute of the node
self.set_nodeattr("lhs_dtype", dtype.name)
self.set_nodeattr("lhs_dtype", lhs_dtype.name)
# Annotate the tensor with the new data type
model.set_tensor_datatype(self.onnx_node.input[0], dtype)
model.set_tensor_datatype(self.onnx_node.input[0], lhs_dtype)

# Check for an initializer providing the right hand side input
rhs = model.get_initializer(self.onnx_node.input[1])
Expand All @@ -368,19 +384,35 @@ def minimize_weight_bit_width(self, model: ModelWrapper):
# Remember the "style" of receiving the input for further code
# generation
self.set_nodeattr("rhs_style", "const")
# Minimum and maximum "weight" on the right hand side, determining
# the range of values which needs to be represented
_min = rhs.min()
_max = rhs.max()
# Determine whether signed or unsigned type is required for
# representing the weights and select the largest "signed magnitude"
_mag = _max if _min > 0 else _min if (abs(_min) > _max) else (-_max - 1)
# Smallest data type large enough to represent this range of values
dtype = DataType.get_smallest_possible(_mag)
rhs_dtype = self.get_input_datatype(1)
# ignore minimization for floats
if not rhs_dtype.get_canonical_name().startswith("FLOAT"):
if rhs_dtype.is_integer():
# Minimum and maximum "weight" on the right hand side, determining
# the range of values which needs to be represented
_min = rhs.min()
_max = rhs.max()
# Determine whether signed or unsigned type is required for
# representing the weights and select the largest "signed magnitude"
_mag = _max if _min > 0 else _min if (abs(_min) > _max) else (-_max - 1)
# Smallest data type large enough to represent this range of values
rhs_dtype = DataType.get_smallest_possible(_mag)
elif rhs_dtype.is_fixed_point():
# Convert the fixed-point array to corresponding integers and get
# smallest integer representation
rhs = rhs / rhs_dtype.scale_factor()
_min = rhs.min()
_max = rhs.max()
_mag = _max if _min > 0 else _min if (abs(_min) > _max) else (-_max - 1)
dtype = DataType.get_smallest_possible(_mag)
_total_bits = dtype.bitwidth() if dtype.signed() else dtype.bitwidth() + 1
_integer_bits = _total_bits - rhs_dtype.frac_bits()
rhs_dtype = DataType[f"FIXED<{_total_bits},{_integer_bits}>"]

# Update the corresponding data type attribute of the node
self.set_nodeattr("rhs_dtype", dtype.name)
self.set_nodeattr("rhs_dtype", rhs_dtype.name)
# Annotate the tensor with the new data type
model.set_tensor_datatype(self.onnx_node.input[1], dtype)
model.set_tensor_datatype(self.onnx_node.input[1], rhs_dtype)

# TODO: MVAU returns the data type here, which does not make sense for
# potentially two data types changing and apparently, the
Expand Down
30 changes: 23 additions & 7 deletions tests/fpgadataflow/test_fpgadataflow_elementwise_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ def create_elementwise_binary_operation_onnx(
# Data type of the left-hand-side and right-hand-side input elements
@pytest.mark.parametrize(
"lhs_dtype_rhs_dtype",
[("INT8", "INT8"), ("INT8", "FLOAT32"), ("FLOAT32", "FLOAT32"), ("FLOAT16", "FLOAT16")],
[
("INT8", "INT8"),
("INT8", "FLOAT32"),
("FLOAT32", "FLOAT32"),
("FLOAT16", "FLOAT16"),
("FIXED<8,4>", "FIXED<10,5>"),
],
)
# Shape of the left-hand-side input
@pytest.mark.parametrize("lhs_shape", [[3, 1, 7, 1], [1]])
Expand All @@ -156,12 +162,16 @@ def test_elementwise_binary_operation(
op_type, lhs_dtype_rhs_dtype, lhs_shape, rhs_shape, pe, initializers, exec_mode
):
lhs_dtype, rhs_dtype = lhs_dtype_rhs_dtype
if "Bitwise" in op_type and (lhs_dtype.startswith("FLOAT") or rhs_dtype.startswith("FLOAT")):
pytest.skip("Float datatypes are not meaningful for bitwise ops, skipping those tests.")
if op_type in ["ElementwiseAnd", "ElementwiseOr", "ElementwiseXor"] and (
lhs_dtype.startswith("FLOAT") or rhs_dtype.startswith("FLOAT")
if "Bitwise" in op_type and not ("INT" in lhs_dtype and "INT" in rhs_dtype):
pytest.skip(
"Non-integer datatypes are not meaningful for bitwise ops, skipping those tests."
)
if op_type in ["ElementwiseAnd", "ElementwiseOr", "ElementwiseXor"] and not (
"INT" in lhs_dtype and "INT" in rhs_dtype
):
pytest.skip("Float datatypes are not meaningful for logical ops, skipping those tests.")
pytest.skip(
"Non-integer datatypes are not meaningful for logical ops, skipping those tests."
)
out_dtype = "FLOAT16" if lhs_dtype == "FLOAT16" and rhs_dtype == "FLOAT16" else "FLOAT32"
# Make dummy model for testing
model = create_elementwise_binary_operation_onnx(
Expand Down Expand Up @@ -248,7 +258,13 @@ def test_elementwise_binary_operation(
# Data type of the left-hand-side and right-hand-side input elements
@pytest.mark.parametrize(
"lhs_dtype_rhs_dtype",
[("INT8", "INT8"), ("INT8", "FLOAT32"), ("FLOAT32", "FLOAT32"), ("FLOAT16", "FLOAT16")],
[
("INT8", "INT8"),
("INT8", "FLOAT32"),
("FLOAT32", "FLOAT32"),
("FLOAT16", "FLOAT16"),
("FIXED<8,4>", "FIXED<10,5>"),
],
)
# Shape of the left-hand-side input
@pytest.mark.parametrize("lhs_shape", [[3, 1, 7, 1]])
Expand Down