179 changes: 131 additions & 48 deletions src/finn/transformation/streamline/absorb.py
@@ -31,73 +31,156 @@
 import warnings
 from onnx import helper as oh
 from qonnx.core.datatype import DataType
+from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op.registry import getCustomOp
 from qonnx.transformation.base import Transformation
 from qonnx.transformation.infer_datatypes import InferDataTypes
+from qonnx.transformation.infer_shapes import InferShapes
 from qonnx.util.basic import get_by_name
 
 
+# Note: The old name is kept for compatibility reasons, but this transformation
+# actually absorbs any bias irrespective of signedness, which might result in a
+# changed signedness of the output type
 class AbsorbSignBiasIntoMultiThreshold(Transformation):
     """Absorb scalar bias originating from signed int export back into
     MultiThreshold and re-evaluate the output datatype."""
 
-    def apply(self, model):
+    def apply(self, model: ModelWrapper):
+        # Get the model graph out of the model wrapper object
         graph = model.graph
-        node_ind = 0
+        # Keep track of whether the graph has been modified
         graph_modified = False
-        for n in graph.node:
-            # search for (MultiThreshold, Add) pair
-            node_ind += 1
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # Only non-branching threshold operations are supported
             if (
-                n.op_type == "MultiThreshold"
-                and not model.is_fork_node(n)
-                and not model.is_join_node(n)
+                node.op_type == "MultiThreshold"
+                and not model.is_fork_node(node)
+                and not model.is_join_node(node)
             ):
-                consumer = model.find_consumer(n.output[0])
+                # We know we are not forking, so there is at most one consumer
+                consumer = model.find_consumer(node.output[0])
+                # At the end of the graph we might have no consumer. If we have
+                # one, handle only Adds; turn Sub into Add first...
                 if consumer is not None and consumer.op_type == "Add":
-                    mt_node = n
-                    add_node = consumer
-                    threshold_name = mt_node.input[1]
-                    add_weight_name = add_node.input[1]
-                    T = model.get_initializer(threshold_name)
-                    A = model.get_initializer(add_weight_name)
-                    if (A is None) or (T is None):
-                        warnings.warn("Threshold or add bias not constant, skipping")
-                        continue
-                    end_name = add_node.output[0]
-                    # we can only absorb scalar adds
-                    is_scalar = A.ndim == 0 or all(x == 1 for x in A.shape)
-                    if not is_scalar:
-                        continue
-                    bias = A.flatten()[0]
-                    # set MultiThreshold bias property
-                    mt_inst = getCustomOp(mt_node)
-                    bias += mt_inst.get_nodeattr("out_bias")
-                    mt_inst.set_nodeattr("out_bias", bias)
+                    # Try to get the parameter tensor for the addition: Sanity
+                    # check whether this is present, even though we already
+                    # tested for non-joining
+                    bias = model.get_initializer(consumer.input[1])
+
+                    # Warn and skip if there is no constant bias present
+                    if bias is None:
+                        warnings.warn(
+                            f"{self.__class__.__name__}: Bias not constant for"
+                            f" {consumer.name}, skipping."
+                        )
+                        # Skip to next node, nothing changed so far, no need to
+                        # break here
+                        continue
+
+                    # Try to get the parameter tensor for the thresholds: Sanity
+                    # check whether this is present, even though we already
+                    # tested for non-joining
+                    thresholds = model.get_initializer(node.input[1])
+
+                    # Warn and skip if there are no constant thresholds present
+                    if thresholds is None:
+                        warnings.warn(
+                            f"{self.__class__.__name__}: Thresholds not"
+                            f" constant for {node.name}, skipping."
+                        )
+                        # Skip to next node, nothing changed so far, no need to
+                        # break here
+                        continue
+
+                    # Check whether the bias is a scalar, as we cannot absorb
+                    # full tensors into node attributes
+                    if not (bias.ndim == 0 or all(x == 1 for x in bias.shape)):
+                        warnings.warn(
+                            f"{self.__class__.__name__}: Bias not scalar"
+                            f" for {consumer.name}, skipping."
+                        )
+                        # Skip to next node, nothing changed so far, no need to
+                        # break here
+                        continue
+
+                    # CustomOp instance of the thresholding node required for
+                    # convenient attribute manipulation
+                    threshold_op = getCustomOp(node)
+                    # Remember the old datatype for some further checks and info
+                    old_odt = threshold_op.get_nodeattr("out_dtype")
+                    # Get the number of bits currently used to represent the
+                    # output values
+                    bits = DataType[old_odt].bitwidth()  # noqa: bitwidth?
+                    # Check whether these thresholds have been generated from a
+                    # narrow range quantizer
+                    narrow = int(thresholds.shape[-1] < (2**bits - 1))
+
+                    # Flatten effectively scalar bias tensors and extract to
+                    # have a "plain" scalar
+                    bias = bias.flatten()[0]
+                    # Shift the output bias of the thresholding operator
+                    out_bias = threshold_op.get_nodeattr("out_bias") + bias
+                    # Derive the new output range due to shifting the bias
+                    # Note: We count threshold steps on top of the bias
+                    new_min = out_bias - narrow
+                    new_max = out_bias + thresholds.shape[-1]
+
+                    # Allow the signedness to change depending on the new
+                    # output range [new_min, new_max]
+                    if abs(new_min) >= abs(new_max):
+                        odt = DataType.get_smallest_possible(new_min)
+                    else:
+                        odt = DataType.get_smallest_possible(new_max)
+
+                    # Check whether the new range can be represented with the
+                    # derived integer datatype
+                    if not (odt.allowed(new_max) and odt.allowed(new_min)):
+                        # Cannot be represented, warn and skip transforming
+                        warnings.warn(
+                            f"{self.__class__.__name__}: Cannot absorb bias"
+                            f" from {consumer.name} into {node.name}: {bias}"
+                        )
+                        # Skip to the next candidate node
+                        continue
+
+                    # Check whether the datatype changes, as this is something
+                    # the "user" should be aware of
+                    if odt.name != old_odt:
+                        warnings.warn(
+                            f"{self.__class__.__name__}: Output datatype for"
+                            f" {node.name} changing from {old_odt} to {odt}"
+                        )
+
+                    # Up until now we did not modify the nodes/graph, just did
+                    # some checks and derived the new bias and datatype. Start
+                    # inserting this back into the graph now...
+
+                    # Set new bias and datatype attributes into the threshold
+                    # operator
+                    threshold_op.set_nodeattr("out_bias", out_bias)
+                    threshold_op.set_nodeattr("out_dtype", odt.name)
+                    # Remove the bias operator and rewire the graph to skip the
+                    # now-missing node
+                    node.output[0] = consumer.output[0]
+                    graph.node.remove(consumer)
+                    # Update the datatype at the output of the threshold
+                    # operation
+                    model.set_tensor_datatype(node.output[0], odt)
+
+                    # Graph modified, so we need to apply this transformation
+                    # again
                     graph_modified = True
-                    # compute new DataType for MultiThreshold output
-                    steps = T.shape[-1]
-                    new_min = bias
-                    new_max = steps + bias
-                    odt = DataType.get_smallest_possible(steps).name.replace("UINT", "INT")
-                    odt = DataType[odt]
-                    assert odt.allowed(new_max) and odt.allowed(
-                        new_min
-                    ), """Could
-                    not compute new MultiThreshold DataType (min = %d max = %d)""" % (
-                        new_min,
-                        new_max,
-                    )
-                    mt_inst.set_nodeattr("out_dtype", odt.name)
-                    # remove Add node, rewire MultiThreshold
-                    graph.node.remove(add_node)
-                    mt_node.output[0] = end_name
-                    # set datatype
-                    model.set_tensor_datatype(end_name, odt)
-        if graph_modified:
-            model = model.transform(InferDataTypes())
-        return (model, graph_modified)
+                    # Better break now to clean up and recover annotations first
+                    break
+        # As we might have changed types and removed nodes, better redo some
+        # annotations
+        model = model.transform(InferDataTypes())
+        model = model.transform(InferShapes())
+        # Transformed model and indication whether the transformation should be
+        # applied again
+        return model, graph_modified
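
To make the range arithmetic in the new code concrete, here is a minimal standalone sketch of the same derivation, using plain Python integers in place of qonnx's DataType machinery. Note that smallest_possible_int below is a hypothetical stand-in for DataType.get_smallest_possible and is not part of FINN or QONNX.

def smallest_possible_int(value: int) -> str:
    """Return the name of the smallest integer type able to hold `value`."""
    if value < 0:
        # Signed: INT<n> covers [-2**(n-1), 2**(n-1) - 1]
        n = 2
        while not -(2 ** (n - 1)) <= value <= 2 ** (n - 1) - 1:
            n += 1
        return f"INT{n}"
    # Non-negative: UINT<n> covers [0, 2**n - 1]
    n = 1
    while value > 2**n - 1:
        n += 1
    return f"UINT{n}"

# Worked example: a UINT8 thresholding output (bits = 8) with 255 threshold
# steps (full range, so narrow = 0) absorbing a signed-export bias of -128:
bits, steps = 8, 255
narrow = int(steps < (2**bits - 1))  # 0: thresholds are not narrow-range
out_bias = 0 + (-128)                # old out_bias of 0 shifted by the bias
new_min = out_bias - narrow          # -128
new_max = out_bias + steps           # 127
# The bound of larger magnitude decides the candidate type, as in apply()
dominant = new_min if abs(new_min) >= abs(new_max) else new_max
print(smallest_possible_int(dominant))  # INT8: [-128, 127] covers both bounds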

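And a sketch of how this transformation would typically be invoked, following the usual FINN/QONNX transform flow; the file names are placeholders:

from qonnx.core.modelwrapper import ModelWrapper

from finn.transformation.streamline.absorb import AbsorbSignBiasIntoMultiThreshold

# "quant_model.onnx" is a placeholder for a model containing a
# MultiThreshold -> Add pair as produced by signed-integer export
model = ModelWrapper("quant_model.onnx")
# transform() keeps re-applying apply() until it reports no further changes
model = model.transform(AbsorbSignBiasIntoMultiThreshold())
model.save("quant_model_streamlined.onnx")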

class AbsorbAddIntoMultiThreshold(Transformation):