7
7
import operator
8
8
9
9
import torch
10
+ from executorch .backends .transforms .utils import (
11
+ create_constant_placeholder ,
12
+ delete_constant_placeholder ,
13
+ )
10
14
11
15
from executorch .backends .xnnpack ._passes .xnnpack_pass import XNNPACKPass
12
16
13
- from executorch .backends .xnnpack .utils .utils import get_param_tensor , is_param_node
17
+ from executorch .backends .xnnpack .utils .utils import (
18
+ get_param_tensor ,
19
+ get_tensor_name ,
20
+ is_param_node ,
21
+ )
14
22
from executorch .exir import ExportedProgram
15
23
from executorch .exir .dialects ._ops import ops as exir_ops
16
24
from executorch .exir .pass_base import PassResult
25
+ from torch .export .graph_signature import InputKind
17
26
18
27
from torch .nn .utils .fusion import fuse_conv_bn_weights
19
28
@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
28
37
29
38
def call (self , graph_module : torch .fx .GraphModule ):
30
39
graph = graph_module .graph
31
- counter = 0
40
+ constant_placeholders_to_delete = set ()
32
41
for conv in graph .nodes :
33
42
# We want to discover a chain of conv -> batch_norm.
34
43
# Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
55
64
assert len (conv .args ) == 9
56
65
57
66
conv_weight = get_param_tensor (self .exported_program , conv .args [1 ])
67
+ conv_weight_name = get_tensor_name (self .exported_program , conv .args [1 ])
58
68
assert conv_weight is not None
59
69
60
70
conv_bias = get_param_tensor (self .exported_program , conv .args [2 ])
71
+ conv_bias_name = get_tensor_name (self .exported_program , conv .args [2 ])
61
72
62
73
# Get the parameters from the batchnorm op
63
74
assert (
@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
95
106
bn_bias ,
96
107
is_transpose ,
97
108
)
109
+ fused_weight_name = (conv_weight_name + "_fused_bn" ).replace ("." , "_" )
110
+ if conv_bias_name == "" :
111
+ fused_bias_name = (conv_weight_name + "_bias_fused_bn" ).replace (
112
+ "." , "_"
113
+ )
114
+ else :
115
+ fused_bias_name = (conv_bias_name + "_fused_bn" ).replace ("." , "_" )
98
116
99
117
# Modify the graph by updating the weight and bias of conv op
100
118
# with the fused weight and bias params, and replacing all the users
101
119
# of getitem(batchnorm) with the conv op.
102
- with graph .inserting_before (conv ):
103
- fused_weight_name = f"_fused_with_bn_weight_{ counter } "
104
- graph_module .register_parameter (fused_weight_name , fused_weight )
105
- fused_weight_node = graph .get_attr (fused_weight_name )
106
- fused_bias_name = f"_fused_with_bn_bias_{ counter } "
107
- graph_module .register_parameter (fused_bias_name , fused_bias )
108
- fused_bias_node = graph .get_attr (fused_bias_name )
109
-
110
- # Update the weight and bias of conv op
111
- conv_args = list (conv .args ) + ([None ] if len (conv .args ) == 2 else [])
112
- conv_args [1 ] = fused_weight_node
113
- conv_args [2 ] = fused_bias_node
114
- conv .args = tuple (conv_args )
120
+ with graph .inserting_before (conv .args [1 ]):
121
+ fused_conv_weight_node = create_constant_placeholder (
122
+ exp_program = self .exported_program ,
123
+ graph = graph_module .graph ,
124
+ kind = InputKind .PARAMETER ,
125
+ name = fused_weight_name ,
126
+ data = fused_weight ,
127
+ )
128
+ if fused_bias is not None :
129
+ fused_conv_bias_node = create_constant_placeholder (
130
+ exp_program = self .exported_program ,
131
+ graph = graph_module .graph ,
132
+ kind = InputKind .PARAMETER ,
133
+ name = fused_bias_name ,
134
+ data = fused_bias ,
135
+ )
136
+ else :
137
+ fused_conv_bias_node = None
138
+
139
+ conv .args = (
140
+ conv .args [0 ],
141
+ fused_conv_weight_node ,
142
+ fused_conv_bias_node ,
143
+ * conv .args [3 :],
144
+ )
145
+
115
146
# Remove any use of batchnorm from the graph
116
147
for user in bn .users .copy ():
117
148
assert user .target == operator .getitem
118
149
user .replace_all_uses_with (conv )
119
150
graph .erase_node (user )
120
151
121
152
graph .erase_node (bn )
153
+ constant_placeholders_to_delete .update (conv .args [1 :3 ] + bn .args [1 :5 ])
122
154
123
- counter += 1
155
+ if len (constant_placeholders_to_delete ) > 0 :
156
+ graph_module .graph .eliminate_dead_code ()
157
+ for node in constant_placeholders_to_delete :
158
+ if (node is not None ) and (len (node .users ) == 0 ):
159
+ delete_constant_placeholder (self .exported_program , node )
124
160
125
161
graph_module .recompile ()
126
162
# To Regenerate meta data and shape information, retrace module
0 commit comments