Commit 5f05cf7

pytorchbot and mcr229 authored
[ExecuTorch][Weight Sharing][XNNPACK] Serialize constant tensors into named data map (#9295)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #9153 by @mcr229
^ Please use this as the source of truth for the PR details, comments, and reviews

ghstack PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/9/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/9/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/8/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/9/orig
@diff-train-skip-merge

Co-authored-by: Max Ren <[email protected]>
1 parent 6ab0019 commit 5f05cf7
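In short, this change stops serializing XNNPACK constant tensor bytes inline (addressed by offset) and instead registers each tensor in a named data map keyed by its fully qualified tensor name. That keying is the groundwork for weight sharing: two serialized graphs that reference the same key can resolve to a single copy of the data. A minimal sketch of the idea in plain Python (hypothetical store, not the ExecuTorch API):

# Minimal sketch (hypothetical, not ExecuTorch API): a name-addressed store
# naturally deduplicates identical constants, which offset-addressed inline
# bytes cannot do across separately serialized graphs.
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class NamedStore:
    blobs: Dict[str, bytes] = field(default_factory=dict)

    def add(self, key: str, data: bytes) -> None:
        # The same key must always map to the same bytes.
        if self.blobs.get(key, data) != data:
            raise ValueError(f"conflicting data for key {key!r}")
        self.blobs[key] = data


store = NamedStore()
store.add("conv1.weight", b"\x00" * 16)
store.add("conv1.weight", b"\x00" * 16)  # shared weight: stored only once
assert len(store.blobs) == 1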

File tree

9 files changed, +106 -28 lines

backends/xnnpack/_passes/TARGETS (+1)

@@ -19,5 +19,6 @@ python_library(
         "//executorch/exir/passes:const_prop_pass",
         "//executorch/exir/passes:memory_format_ops_pass",
         "//executorch/exir/program:program",
+        "//executorch/backends/transforms:utils",
     ],
 )

backends/xnnpack/_passes/fuse_batch_norm_with_conv.py (+52 -16)

@@ -7,13 +7,22 @@
 import operator
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
 
-from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
+from executorch.backends.xnnpack.utils.utils import (
+    get_param_tensor,
+    get_tensor_name,
+    is_param_node,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
+from torch.export.graph_signature import InputKind
 
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 

@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
 
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        counter = 0
+        constant_placeholders_to_delete = set()
         for conv in graph.nodes:
             # We want to discover a chain of conv -> batch_norm.
             # Only proceed if the current node is a conv node, and has a single

@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
             assert len(conv.args) == 9
 
             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
             assert conv_weight is not None
 
             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
+            conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])
 
             # Get the parameters from the batchnorm op
             assert (

@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
                 bn_bias,
                 is_transpose,
             )
+            fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
+            if conv_bias_name == "":
+                fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
+                    ".", "_"
+                )
+            else:
+                fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")
 
             # Modify the graph by updating the weight and bias of conv op
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the conv op.
-            with graph.inserting_before(conv):
-                fused_weight_name = f"_fused_with_bn_weight_{counter}"
-                graph_module.register_parameter(fused_weight_name, fused_weight)
-                fused_weight_node = graph.get_attr(fused_weight_name)
-                fused_bias_name = f"_fused_with_bn_bias_{counter}"
-                graph_module.register_parameter(fused_bias_name, fused_bias)
-                fused_bias_node = graph.get_attr(fused_bias_name)
-
-            # Update the weight and bias of conv op
-            conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
-            conv_args[1] = fused_weight_node
-            conv_args[2] = fused_bias_node
-            conv.args = tuple(conv_args)
+            with graph.inserting_before(conv.args[1]):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=fused_weight_name,
+                    data=fused_weight,
+                )
+                if fused_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=fused_bias_name,
+                        data=fused_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )
+
             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
                 assert user.target == operator.getitem
                 user.replace_all_uses_with(conv)
                 graph.erase_node(user)
 
             graph.erase_node(bn)
+            constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])
 
-            counter += 1
+        if len(constant_placeholders_to_delete) > 0:
+            graph_module.graph.eliminate_dead_code()
+            for node in constant_placeholders_to_delete:
+                if (node is not None) and (len(node.users) == 0):
+                    delete_constant_placeholder(self.exported_program, node)
 
         graph_module.recompile()
         # To Regenerate meta data and shape information, retrace module
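Note the fusion math itself is untouched: the pass still calls torch.nn.utils.fusion.fuse_conv_bn_weights and only changes how the fused tensors are registered (named constant placeholders instead of anonymous counter-named parameters, so they serialize under stable names). As a standalone sanity check of that fusion (my addition, standard PyTorch API), this sketch verifies the folded conv matches conv followed by eval-mode batch norm:

# Standalone sanity check (not from this PR): folding batch norm into conv
# weights must match conv -> batch_norm run in eval mode, which is exactly the
# equivalence this pass relies on.
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # non-trivial statistics
bn.running_var.uniform_(0.5, 2.0)

fused_w, fused_b = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)

fused_conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
fused_conv.weight, fused_conv.bias = fused_w, fused_b

x = torch.randn(1, 3, 16, 16)
torch.testing.assert_close(fused_conv(x), bn(conv(x)), rtol=1e-4, atol=1e-5)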

backends/xnnpack/operators/node_visitor.py (+17 -9)

@@ -34,20 +34,23 @@
     check_or_raise,
     get_input_node,
     get_param_tensor,
+    get_tensor_name,
     is_param_node,
     PERM_NCHW_TO_NHWC,
 )
 
-from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
+from executorch.backends.xnnpack.utils.xnnpack_constants import (
+    UINT64_MAX,
+    XNN_INVALID_VALUE_ID,
+)
+from executorch.exir._serialize._named_data_store import NamedDataStore
 from torch.export import ExportedProgram
 
 XNN_TYPE_MAP = {
     torch.float32: XNNDatatype.xnn_datatype_fp32,
 }
 
 from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
-    _aligned_size,
-    _pad_to,
     CONSTANT_TENSOR_ALIGNMENT,
 )

@@ -86,11 +89,11 @@ def __init__(
         self,
         exported_program: ExportedProgram,
         external_ids: Dict,
-        constant_data_bytes: bytearray,
+        named_data_store: NamedDataStore,
     ) -> None:
         self._external_ids = external_ids or {}
         self._exported_program = exported_program or None
-        self._constant_data_bytes = constant_data_bytes
+        self._named_data_store = named_data_store
 
     @property
     def external_ids(self) -> Dict:

@@ -579,11 +582,16 @@ def get_serialized_buffer_index(
             ctypes.POINTER(array_type),
         ).contents
 
-        offset = len(self._constant_data_bytes)
+        named_key = get_tensor_name(self.exported_program, get_attr_node)
+        if named_key == "":
+            raise ValueError(f"Tensor from node: {get_attr_node} has no name")
+
         size = const_val.untyped_storage().nbytes()
-        xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
-        self._constant_data_bytes.extend(
-            _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
+        xnn_graph.constant_data.append(
+            ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
+        )
+        self._named_data_store.add_named_data(
+            named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
         )
 
         return buffer_idx
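The visitor-side contract is small: every constant must have a non-empty name, the flatbuffer entry records offset=UINT64_MAX plus that name, and the raw bytes go to the store. A hedged usage sketch built only from the two NamedDataStore calls visible in this diff (the key and alignment value below are illustrative stand-ins for the tensor's qualified name and CONSTANT_TENSOR_ALIGNMENT):

# Hedged usage sketch of the store as used above; only add_named_data and
# get_named_data_store_output appear in this diff.
from executorch.exir._serialize._named_data_store import NamedDataStore

store = NamedDataStore()
# One entry per constant tensor, keyed by its fully qualified name. The
# alignment of 16 is illustrative; the visitor passes CONSTANT_TENSOR_ALIGNMENT.
store.add_named_data("conv1.weight", b"\x01\x02\x03\x04", alignment=16)

# The aggregated blobs are later attached to the PreprocessResult (see
# xnnpack_preprocess.py below).
output = store.get_named_data_store_output()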

backends/xnnpack/serialization/schema.fbs (+9)

@@ -316,11 +316,20 @@ table XNNLeakyReLU {
 table ConstantDataOffset {
   // Constant data offsets are relative to the constant data base offset provided
   // in the XNNPACKHeader.
+  // named_key and offset are mutually exclusive, meaning only one of these values
+  // are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
+  // If the offset is not UINT64_MAX, then the named key must be an empty string
   offset: uint64;
 
   // The size in bytes of valid data starting at the offset. The constant data
   // may be followed by padding before the next piece of constant data
   size: uint64;
+
+  // unique string id used to query the offset from the named data store.
+  // named_key and offset are mutually exclusive, meaning only one of these values
+  // are valid. If the named key is a non-empty string, then the offset must be UINT64_MAX.
+  // If the offset is not UINT64_MAX, then the named key must be an empty string
+  named_key: string;
 }
 
 table XNNGraph {
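The schema comments above amount to a simple invariant: exactly one of offset and named_key carries information. A small self-contained check (hypothetical helper, not part of this PR) makes that concrete:

# Hypothetical validator (my addition) for the invariant stated in the schema
# comments: a non-empty named_key requires offset == UINT64_MAX, and a real
# offset requires an empty named_key.
UINT64_MAX = 2**64 - 1


def is_valid_constant_data_offset(offset: int, named_key: str) -> bool:
    if named_key != "":
        return offset == UINT64_MAX  # named lookup: offset must be the sentinel
    return offset != UINT64_MAX  # inline data: a real offset is required


assert is_valid_constant_data_offset(UINT64_MAX, "conv1.weight")
assert is_valid_constant_data_offset(4096, "")
assert not is_valid_constant_data_offset(4096, "conv1.weight")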

backends/xnnpack/serialization/xnnpack_graph_schema.py (+1)

@@ -470,6 +470,7 @@ class XValue:
 class ConstantDataOffset:
     offset: int
     size: int
+    named_key: str = ""
 
 
 @dataclass

backends/xnnpack/utils/gen_xnnpack_constants.sh (+1)

@@ -26,5 +26,6 @@
 } > xnnpack_constants.py
 
 echo UINT32_MAX = 4294967295 >> xnnpack_constants.py
+echo UINT64_MAX = 18446744073709551615 >> xnnpack_constants.py
 awk '/^#define\s+XNN_/ { print $2,"=",$3} ' "$1"/include/xnnpack.h >> xnnpack_constants.py
 if ! grep -qc "^XNN_" xnnpack_constants.py; then false; fi

backends/xnnpack/utils/utils.py (+16)

@@ -131,6 +131,22 @@ def get_param_tensor(
     raise RuntimeError(f"unsupported param type, {node.op}.")
 
 
+def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str:
+    if node is None:
+        return ""
+    if is_param(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_parameters[node.name]
+    elif is_buffer(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_buffers[node.name]
+    elif is_lifted_tensor_constant(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
+    else:
+        assert isinstance(node.target, str)
+        return node.target
+
+    return ""
+
+
 def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     """
     Returns the source fn of the given node, return None if something goes wrong
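get_tensor_name resolves a lifted placeholder back to the fully qualified name it had in the original module, via the export graph signature. A standalone sketch of that lookup using torch.export directly (outside ExecuTorch; the placeholder naming in the comment is an assumption about torch.export's current scheme):

# Standalone demonstration (my addition): mapping placeholder nodes back to
# fully qualified parameter names through the graph signature, the same table
# get_tensor_name consults for parameters.
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(M(), (torch.randn(1, 4),))
for node in ep.graph.nodes:
    if node.op == "placeholder" and node.name in ep.graph_signature.inputs_to_parameters:
        # e.g. placeholder "p_linear_weight" -> "linear.weight" (the exact
        # placeholder name format is an assumption)
        print(node.name, "->", ep.graph_signature.inputs_to_parameters[node.name])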

backends/xnnpack/utils/xnnpack_constants.py (+5 -1)

@@ -6,8 +6,11 @@
 
 # Auto-generated by gen_xnnpack_constants.sh script. Do not modify
 UINT32_MAX = 4294967295
+UINT64_MAX = 18446744073709551615
+XNN_EXTRA_BYTES = 128
 XNN_EXTRA_BYTES = 16
 XNN_MAX_TENSOR_DIMS = 6
+XNN_INVALID_VALUE_ID = UINT32_MAX
 XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001
 XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002
 XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004

@@ -26,7 +29,8 @@
 XNN_FLAG_YIELD_WORKERS = 0x00000010
 XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020
 XNN_FLAG_KEEP_DIMS = 0x00000040
-XNN_EXTRA_QUANTIZATION_PARAMS = 8
+XNN_EXTRA_QUANTIZATION_PARAMS = 10
+XNN_MIN_BLOCKSIZE = 32
 XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001
 XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002
 XNN_VALUE_FLAG_PERSISTENT = 0x00000004

backends/xnnpack/xnnpack_preprocess.py (+4 -2)

@@ -31,6 +31,7 @@
     XNN_VALUE_FLAG_EXTERNAL_INPUT,
     XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
 )
+from executorch.exir._serialize._named_data_store import NamedDataStore
 
 from executorch.exir.backend.backend_details import (
     BackendDetails,

@@ -103,7 +104,7 @@ def preprocess(
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> PreprocessResult:
-
+        named_data_store = NamedDataStore()
         xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()
 
         # Need to wrap EP here because xnnpack does addmm to linear

@@ -162,7 +163,7 @@
         )
 
         constant_data_bytes = bytearray()
-        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)
+        node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store)
 
         for node in graph_module.graph.nodes:
             if node.op == "call_function":

@@ -191,4 +192,5 @@
                 xnnpack_graph, constant_data_bytes
             ),
             debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
         )
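On the consumer side, a ConstantDataOffset entry now resolves one of two ways. A hypothetical resolver sketch (the real runtime logic is C++ and not part of this diff; this mirrors the Python schema class only to illustrate the two paths):

# Hypothetical consumer-side sketch (my addition, not ExecuTorch runtime code):
# resolving a ConstantDataOffset against a named data map, with the legacy
# offset path as fallback.
from dataclasses import dataclass
from typing import Dict

UINT64_MAX = 2**64 - 1


@dataclass
class ConstantDataOffset:  # mirrors the Python schema class updated in this PR
    offset: int
    size: int
    named_key: str = ""


def resolve_constant(
    entry: ConstantDataOffset, named_data: Dict[str, bytes], inline: bytes
) -> bytes:
    """Return the raw bytes for one constant tensor."""
    if entry.named_key != "":
        # New path: look the blob up by name in the named data map.
        return named_data[entry.named_key]
    # Legacy path: slice the inline constant-data segment by offset.
    return inline[entry.offset : entry.offset + entry.size]


blob = resolve_constant(
    ConstantDataOffset(offset=UINT64_MAX, size=4, named_key="conv1.weight"),
    named_data={"conv1.weight": b"\x00\x01\x02\x03"},
    inline=b"",
)
assert blob == b"\x00\x01\x02\x03"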
