[ExecuTorch][Weight Sharing][XNNPACK] Serialize constant tensors into named data map #9153

Merged: 5 commits, Mar 14, 2025
1 change: 1 addition & 0 deletions backends/xnnpack/_passes/TARGETS
@@ -19,5 +19,6 @@ python_library(
         "//executorch/exir/passes:const_prop_pass",
         "//executorch/exir/passes:memory_format_ops_pass",
         "//executorch/exir/program:program",
+        "//executorch/backends/transforms:utils",
     ],
 )
68 changes: 52 additions & 16 deletions backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
@@ -7,13 +7,22 @@
 import operator

 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)

 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

-from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
+from executorch.backends.xnnpack.utils.utils import (
+    get_param_tensor,
+    get_tensor_name,
+    is_param_node,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
+from torch.export.graph_signature import InputKind

 from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):

     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        counter = 0
+        constant_placeholders_to_delete = set()
         for conv in graph.nodes:
             # We want to discover a chain of conv -> batch_norm.
             # Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
             assert len(conv.args) == 9

             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
             assert conv_weight is not None

             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
+            conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])

             # Get the parameters from the batchnorm op
             assert (
@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
                 bn_bias,
                 is_transpose,
             )
+            fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
+            if conv_bias_name == "":
+                fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
+                    ".", "_"
+                )
+            else:
+                fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")

             # Modify the graph by updating the weight and bias of conv op
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the conv op.
-            with graph.inserting_before(conv):
-                fused_weight_name = f"_fused_with_bn_weight_{counter}"
-                graph_module.register_parameter(fused_weight_name, fused_weight)
-                fused_weight_node = graph.get_attr(fused_weight_name)
-                fused_bias_name = f"_fused_with_bn_bias_{counter}"
-                graph_module.register_parameter(fused_bias_name, fused_bias)
-                fused_bias_node = graph.get_attr(fused_bias_name)
-
-            # Update the weight and bias of conv op
-            conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
-            conv_args[1] = fused_weight_node
-            conv_args[2] = fused_bias_node
-            conv.args = tuple(conv_args)
+            with graph.inserting_before(conv.args[1]):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=fused_weight_name,
+                    data=fused_weight,
+                )
+                if fused_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=fused_bias_name,
+                        data=fused_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )

             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
                 assert user.target == operator.getitem
                 user.replace_all_uses_with(conv)
                 graph.erase_node(user)

             graph.erase_node(bn)
+            constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])

-            counter += 1
+        if len(constant_placeholders_to_delete) > 0:
+            graph_module.graph.eliminate_dead_code()
+            for node in constant_placeholders_to_delete:
+                if (node is not None) and (len(node.users) == 0):
+                    delete_constant_placeholder(self.exported_program, node)

         graph_module.recompile()
         # To Regenerate meta data and shape information, retrace module
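For context, a minimal sketch of the arithmetic this pass delegates to torch.nn.utils.fusion.fuse_conv_bn_weights; the shapes and values below are illustrative, not taken from this PR:

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

# Illustrative shapes: 8 output channels, 3 input channels, 3x3 kernel.
conv_w = torch.randn(8, 3, 3, 3)
conv_b = torch.randn(8)
bn_rm = torch.randn(8)        # batch-norm running mean
bn_rv = torch.rand(8) + 1e-3  # batch-norm running variance (positive)
bn_w = torch.randn(8)         # gamma
bn_b = torch.randn(8)         # beta
eps = 1e-5

fused_w, fused_b = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, eps, bn_w, bn_b)

# Equivalent per-output-channel closed form:
#   scale_c = gamma_c / sqrt(var_c + eps)
#   W'_c    = W_c * scale_c
#   b'_c    = (b_c - mean_c) * scale_c + beta_c
```

The change above then registers the fused tensors as new constant placeholders (rather than module attributes) under the `*_fused_bn` names, which is what allows them to be addressed through the named data map.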
26 changes: 17 additions & 9 deletions backends/xnnpack/operators/node_visitor.py
@@ -34,20 +34,23 @@
     check_or_raise,
     get_input_node,
     get_param_tensor,
+    get_tensor_name,
     is_param_node,
     PERM_NCHW_TO_NHWC,
 )

-from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
+from executorch.backends.xnnpack.utils.xnnpack_constants import (
+    UINT64_MAX,
+    XNN_INVALID_VALUE_ID,
+)
+from executorch.exir._serialize._named_data_store import NamedDataStore
 from torch.export import ExportedProgram

 XNN_TYPE_MAP = {
     torch.float32: XNNDatatype.xnn_datatype_fp32,
 }

 from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
-    _aligned_size,
-    _pad_to,
     CONSTANT_TENSOR_ALIGNMENT,
 )
@@ -86,11 +89,11 @@ def __init__(
         self,
         exported_program: ExportedProgram,
         external_ids: Dict,
-        constant_data_bytes: bytearray,
+        named_data_store: NamedDataStore,
     ) -> None:
         self._external_ids = external_ids or {}
         self._exported_program = exported_program or None
-        self._constant_data_bytes = constant_data_bytes
+        self._named_data_store = named_data_store

     @property
     def external_ids(self) -> Dict:
@@ -579,11 +582,16 @@ def get_serialized_buffer_index(
             ctypes.POINTER(array_type),
         ).contents

-        offset = len(self._constant_data_bytes)
+        named_key = get_tensor_name(self.exported_program, get_attr_node)
+        if named_key == "":
+            raise ValueError(f"Tensor from node: {get_attr_node} has no name")
+
         size = const_val.untyped_storage().nbytes()
-        xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
-        self._constant_data_bytes.extend(
-            _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
+        xnn_graph.constant_data.append(
+            ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
         )
+        self._named_data_store.add_named_data(
+            named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
+        )

         return buffer_idx
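A minimal sketch of the store interaction above, using only the two calls that appear in this diff (add_named_data and get_named_data_store_output); the key and payload are made up:

```python
from executorch.exir._serialize._named_data_store import NamedDataStore

store = NamedDataStore()

# Each constant tensor is registered under its fully qualified name, with
# the alignment its consumer requires (illustrative values).
store.add_named_data("conv1.weight", b"\x00" * 64, alignment=16)

# preprocess() later hands the accumulated blobs back to the runtime, which
# resolves each named_key to its bytes at load time.
output = store.get_named_data_store_output()
```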
9 changes: 9 additions & 0 deletions backends/xnnpack/serialization/schema.fbs
@@ -316,11 +316,20 @@ table XNNLeakyReLU {
 table ConstantDataOffset {
   // Constant data offsets are relative to the constant data base offset provided
   // in the XNNPACKHeader.
+  // named_key and offset are mutually exclusive, meaning only one of them is
+  // valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
+  // If the offset is not UINT64_MAX, then the named key must be an empty string.
   offset: uint64;

   // The size in bytes of valid data starting at the offset. The constant data
   // may be followed by padding before the next piece of constant data.
   size: uint64;
+
+  // Unique string id used to query the offset from the named data store.
+  // named_key and offset are mutually exclusive, meaning only one of them is
+  // valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
+  // If the offset is not UINT64_MAX, then the named key must be an empty string.
+  named_key: string;
 }

 table XNNGraph {
1 change: 1 addition & 0 deletions backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -470,6 +470,7 @@ class XValue:
 class ConstantDataOffset:
     offset: int
     size: int
+    named_key: str = ""


 @dataclass
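To make the offset/named_key mutual exclusivity concrete, a sketch of the two valid encodings of this dataclass (the sizes and key are made up):

```python
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import UINT64_MAX

# Legacy form: the payload lives inline at `offset` in the constant segment.
inline = ConstantDataOffset(offset=4096, size=1024)  # named_key defaults to ""

# Named form: offset holds the UINT64_MAX sentinel and the runtime looks the
# payload up in the named data map under `named_key`.
named = ConstantDataOffset(offset=UINT64_MAX, size=1024, named_key="conv1.weight")
```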
1 change: 1 addition & 0 deletions backends/xnnpack/utils/gen_xnnpack_constants.sh
@@ -26,5 +26,6 @@
 } > xnnpack_constants.py

 echo UINT32_MAX = 4294967295 >> xnnpack_constants.py
+echo UINT64_MAX = 18446744073709551615 >> xnnpack_constants.py
 awk '/^#define\s+XNN_/ { print $2,"=",$3} ' "$1"/include/xnnpack.h >> xnnpack_constants.py
 if ! grep -qc "^XNN_" xnnpack_constants.py; then false; fi
16 changes: 16 additions & 0 deletions backends/xnnpack/utils/utils.py
@@ -131,6 +131,22 @@ def get_param_tensor(
     raise RuntimeError(f"unsupported param type, {node.op}.")


+def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str:
+    if node is None:
+        return ""
+    if is_param(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_parameters[node.name]
+    elif is_buffer(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_buffers[node.name]
+    elif is_lifted_tensor_constant(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
+    else:
+        assert isinstance(node.target, str)
+        return node.target
+
+
 def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     """
     Returns the source fn of the given node, return None if something goes wrong
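As an illustration of what this helper returns and how the fusion pass mangles the result; "features.0.weight" is a hypothetical parameter name:

```python
# get_tensor_name() returns the fully qualified name recorded in the graph
# signature, e.g. for a conv parameter node (hypothetical name):
conv_weight_name = "features.0.weight"

# fuse_batch_norm_with_conv.py then derives a flat key for the named data map:
fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
assert fused_weight_name == "features_0_weight_fused_bn"
```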
6 changes: 5 additions & 1 deletion backends/xnnpack/utils/xnnpack_constants.py
@@ -6,8 +6,11 @@

 # Auto-generated by gen_xnnpack_constants.sh script. Do not modify
 UINT32_MAX = 4294967295
+UINT64_MAX = 18446744073709551615
-XNN_EXTRA_BYTES = 128
+XNN_EXTRA_BYTES = 16
 XNN_MAX_TENSOR_DIMS = 6
 XNN_INVALID_VALUE_ID = UINT32_MAX
 XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001
 XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002
 XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004
@@ -26,7 +29,8 @@
 XNN_FLAG_YIELD_WORKERS = 0x00000010
 XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020
 XNN_FLAG_KEEP_DIMS = 0x00000040
-XNN_EXTRA_QUANTIZATION_PARAMS = 8
+XNN_EXTRA_QUANTIZATION_PARAMS = 10
+XNN_MIN_BLOCKSIZE = 32
 XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001
 XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002
 XNN_VALUE_FLAG_PERSISTENT = 0x00000004
6 changes: 4 additions & 2 deletions backends/xnnpack/xnnpack_preprocess.py
@@ -31,6 +31,7 @@
     XNN_VALUE_FLAG_EXTERNAL_INPUT,
     XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
 )
+from executorch.exir._serialize._named_data_store import NamedDataStore

 from executorch.exir.backend.backend_details import (
     BackendDetails,
@@ -103,7 +104,7 @@ def preprocess(
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> PreprocessResult:
-
+        named_data_store = NamedDataStore()
         xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

         # Need to wrap EP here because xnnpack does addmm to linear
@@ -162,7 +163,7 @@ def preprocess(
         )

         constant_data_bytes = bytearray()
-        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)
+        node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store)

         for node in graph_module.graph.nodes:
             if node.op == "call_function":
@@ -191,4 +192,5 @@ def preprocess(
                 xnnpack_graph, constant_data_bytes
             ),
             debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
         )
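Taken together, a condensed sketch of the new preprocess() wiring, assuming serialize_xnnpack_binary and PreprocessResult behave as used above (everything between store creation and serialization is elided):

```python
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    serialize_xnnpack_binary,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import PreprocessResult


def build_preprocess_result(
    xnnpack_graph, constant_data_bytes: bytearray, named_data_store: NamedDataStore
) -> PreprocessResult:
    # Constant tensors no longer travel in constant_data_bytes; the node
    # visitors have already placed them in the named data store by name.
    return PreprocessResult(
        processed_bytes=serialize_xnnpack_binary(xnnpack_graph, constant_data_bytes),
        debug_handle_map={},
        data_store_output=named_data_store.get_named_data_store_output(),
    )
```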