# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, final, List

import torch

from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack.operators.node_visitor import get_node_visitors
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
    XNNGraph,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    serialize_xnnpack_binary,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.backends.xnnpack.utils.xnnpack_constants import (
    XNN_VALUE_FLAG_EXTERNAL_INPUT,
    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
)

from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch.export.exported_program import ExportedProgram

DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

@dataclass
class ExternalMeta:
    external_id: int
    io_type: int

def generate_node_to_external_map(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
) -> Dict[torch.fx.Node, ExternalMeta]:
    node_to_external_map = {}
    for node in edge_graph_module.graph.nodes:
        # The order in which we visit the placeholder nodes is the same as the
        # *args order of the forward(*args) signature for this graph module.
        # Use each node's visit order as its external_id so the right arg can
        # be extracted from *args at runtime.
        #
        # Skip parameters/buffers since they disappear from the signature at
        # runtime.
        if node.op == "placeholder" and not is_param_node(exported_program, node):
            node_to_external_map[node] = ExternalMeta(
                external_id=len(node_to_external_map),
                io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT,
            )
    for node in edge_graph_module.graph.nodes:
        if node.op == "output":
            for output_nodes in node.args:
                for output_node in output_nodes:
                    node_to_external_map[output_node] = ExternalMeta(
                        external_id=len(node_to_external_map),
                        io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                    )
    return node_to_external_map
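
# Illustrative sketch (not part of the original module; node names are
# hypothetical): for a graph module whose signature is forward(x, y) and which
# returns a single tensor, the map built above would assign external IDs in
# visit order, e.g.
#
#     {x_placeholder: ExternalMeta(external_id=0, io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT),
#      y_placeholder: ExternalMeta(external_id=1, io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT),
#      output_node:   ExternalMeta(external_id=2, io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT)}
#
# so the runtime can index into *args (and the outputs) by external_id.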

def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
    for node in edge_graph_module.graph.nodes:
        if node.op != "placeholder":
            continue

        # We expect the default dim order for all tensor-like inputs, i.e.
        # inputs, buffers, and params.
        t = node.meta.get("val", None)
        if t is not None and getattr(t, "dim_order", None) is not None:
            default_dim_order = tuple(range(t.dim()))
            if t.dim_order() != default_dim_order:
                raise RuntimeError(
                    f"XNNPACK backend only supports contiguous memory format for inputs. "
                    f"Expecting dim_order: {default_dim_order}, but got "
                    f"{node.meta['val'].dim_order()} for placeholder node {node}."
                )
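
# Hedged example (illustration only, not part of the original module): for a
# 4-D tensor the default dim order is (0, 1, 2, 3), i.e. contiguous / NCHW,
# which passes the check above, while a channels-last tensor would be rejected:
#
#     t = torch.randn(1, 3, 8, 8)                     # dim_order (0, 1, 2, 3) -> OK
#     t_cl = t.to(memory_format=torch.channels_last)  # dim_order (0, 2, 3, 1) -> RuntimeError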

@final
class XnnpackBackend(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        named_data_store = NamedDataStore()
        xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

        # We need to wrap the ExportedProgram here because XNNPACK performs
        # addmm-to-linear transforms, which make the resulting graph non-ATen
        # compliant, since aten.linear is not a core ATen op. The ideal fix
        # would be an XNNPACK verifier that bypasses most checks, but the base
        # Verifier itself has some strict checks, and to bypass those we would
        # essentially have to copy what EdgeDialectVerifier does. So for now,
        # instead of copy-pasting that, just instantiate EdgeDialectVerifier
        # but disable it.
        # TODO (task link) to implement NullVerifier or something similar
        ep = ExportedProgram(
            root=edge_program.graph_module,
            graph=edge_program.graph,
            graph_signature=edge_program.graph_signature,
            state_dict=edge_program.state_dict,
            range_constraints=edge_program.range_constraints,
            module_call_graph=edge_program.module_call_graph,
            example_inputs=edge_program.example_inputs,
            constants=edge_program.constants,
            verifiers=[
                EXIREdgeDialectVerifier(
                    edge_compile_config=xnnpack_edge_compile_config, class_only=True
                )
            ],
        )

        passes = []
        for spec in compile_specs:
            if spec.key == "dqlinear_partitioner":
                passes.append(ConvertToLinearPass)
                passes.append(TagImplicitQDqPass)
        passes = passes if len(passes) > 0 else None

        # XNNPACK delegate-specific passes
        ep = XNNPACKPassManager(ep, passes=passes).transform()
        graph_module = ep.graph_module

        node_to_external_map = generate_node_to_external_map(ep, graph_module)

        # Make sure all inputs use the default dim order (contiguous_format / NCHW)
        assert_default_dim_order(graph_module)

        # TODO: retrace the graph module to lift any new params that may have
        # been added to the graph by the passes
        vals_to_ids = {}
        xnnpack_graph = XNNGraph(
            version="0",
            xnodes=[],
            xvalues=[],
            num_externs=len(node_to_external_map),
            input_ids=[],
            output_ids=[],
            constant_data=[ConstantDataOffset(0, 0)],
        )

        constant_data_bytes = bytearray()
        node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store)

        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                logger.info(f"Visiting: {node}, {node.target.__name__}")

                if node.target.__name__ in node_visitors:
                    node_visitors[node.target.__name__].define_node(
                        node,
                        xnnpack_graph,
                        vals_to_ids,
                        node.meta.get("debug_handle", DEFAULT_DEBUG_HANDLE),
                    )
                else:
                    raise RuntimeError(
                        f"For {node}, {node.op}:{node.target.__name__} "
                        "is not supported in XNNPACK Delegate"
                    )
            elif node.op in [
                "get_attr",
                "placeholder",
                "output",
            ]:
                continue
            else:
                raise RuntimeError(f"{node.op} is not supported in XNNPACK")

        return PreprocessResult(
            processed_bytes=serialize_xnnpack_binary(
                xnnpack_graph, constant_data_bytes
            ),
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )
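

# A minimal usage sketch (hedged; not part of this module). It follows the
# public ExecuTorch XNNPACK lowering flow, in which partitioning with
# XnnpackPartitioner is what ultimately invokes XnnpackBackend.preprocess on
# each delegated subgraph; the model and shapes below are made up for
# illustration:
#
#     import torch
#     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
#         XnnpackPartitioner,
#     )
#     from executorch.exir import to_edge
#     from torch.export import export
#
#     model = torch.nn.Linear(4, 2).eval()
#     example_inputs = (torch.randn(1, 4),)
#
#     edge = to_edge(export(model, example_inputs))
#     lowered = edge.to_backend(XnnpackPartitioner())  # calls preprocess per subgraph
#     et_program = lowered.to_executorch()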