[XNNPACK][Weights Cache] Enable in XNNPACK #9297


Merged: 3 commits, Mar 15, 2025
13 changes: 13 additions & 0 deletions backends/xnnpack/CMakeLists.txt
@@ -37,6 +37,19 @@ option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE
# Keeping this OFF by default due to regressions in decode and model load with
# kleidi kernels
option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF)

# Turning this on caches weights between partitions and methods. If weights
# are shared across methods/partitions, this can reduce load time and
# memory usage

# Keeping this off maintains existing behavior. Turning this on serializes the
# execution and initialization of delegates; to be revisited.
option(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE
"Enable weights cache to cache and manage all packed weights" OFF)

if(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE)
add_definitions(-DENABLE_XNNPACK_WEIGHTS_CACHE)
endif()
if(EXECUTORCH_XNNPACK_SHARED_WORKSPACE)
add_definitions(-DENABLE_XNNPACK_SHARED_WORKSPACE)
endif()
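For orientation, this option only adds a compile definition; the weights-cache code paths in the runtime are gated on it. A minimal, hypothetical C++ program (not part of this PR) showing the gating pattern when the project is configured with EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=ON:

#include <cstdio>

int main() {
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
  // Set by the CMake option above: packed weights are cached and shared
  // across partitions/methods.
  std::printf("XNNPACK weights cache enabled\n");
#else
  // Default build: existing per-delegate behavior, no weights cache.
  std::printf("XNNPACK weights cache disabled\n");
#endif
  return 0;
}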
1 change: 1 addition & 0 deletions backends/xnnpack/_passes/TARGETS
@@ -19,5 +19,6 @@ python_library(
"//executorch/exir/passes:const_prop_pass",
"//executorch/exir/passes:memory_format_ops_pass",
"//executorch/exir/program:program",
"//executorch/backends/transforms:utils",
],
)
68 changes: 52 additions & 16 deletions backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
@@ -7,13 +7,22 @@
import operator

import torch
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)

from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass

from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
from executorch.backends.xnnpack.utils.utils import (
get_param_tensor,
get_tensor_name,
is_param_node,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.export.graph_signature import InputKind

from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
counter = 0
constant_placeholders_to_delete = set()
for conv in graph.nodes:
# We want to discover a chain of conv -> batch_norm.
# Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
assert len(conv.args) == 9

conv_weight = get_param_tensor(self.exported_program, conv.args[1])
conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
assert conv_weight is not None

conv_bias = get_param_tensor(self.exported_program, conv.args[2])
conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])

# Get the parameters from the batchnorm op
assert (
@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
bn_bias,
is_transpose,
)
fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
if conv_bias_name == "":
fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
".", "_"
)
else:
fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")

# Modify the graph by updating the weight and bias of conv op
# with the fused weight and bias params, and replacing all the users
# of getitem(batchnorm) with the conv op.
with graph.inserting_before(conv):
fused_weight_name = f"_fused_with_bn_weight_{counter}"
graph_module.register_parameter(fused_weight_name, fused_weight)
fused_weight_node = graph.get_attr(fused_weight_name)
fused_bias_name = f"_fused_with_bn_bias_{counter}"
graph_module.register_parameter(fused_bias_name, fused_bias)
fused_bias_node = graph.get_attr(fused_bias_name)

# Update the weight and bias of conv op
conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
conv_args[1] = fused_weight_node
conv_args[2] = fused_bias_node
conv.args = tuple(conv_args)
with graph.inserting_before(conv.args[1]):
fused_conv_weight_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=fused_weight_name,
data=fused_weight,
)
if fused_bias is not None:
fused_conv_bias_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=fused_bias_name,
data=fused_bias,
)
else:
fused_conv_bias_node = None

conv.args = (
conv.args[0],
fused_conv_weight_node,
fused_conv_bias_node,
*conv.args[3:],
)

# Remove any use of batchnorm from the graph
for user in bn.users.copy():
assert user.target == operator.getitem
user.replace_all_uses_with(conv)
graph.erase_node(user)

graph.erase_node(bn)
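# Collect placeholder nodes that may no longer be needed; any left without
# users after dead-code elimination are deleted below.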
constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])

counter += 1
if len(constant_placeholders_to_delete) > 0:
graph_module.graph.eliminate_dead_code()
for node in constant_placeholders_to_delete:
if (node is not None) and (len(node.users) == 0):
delete_constant_placeholder(self.exported_program, node)

graph_module.recompile()
# To regenerate metadata and shape information, retrace the module
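For reference, fuse_conv_bn_weights above folds the batch-norm statistics into the convolution weight and bias. With per-channel running mean \mu, running variance \sigma^2, affine parameters \gamma and \beta, and epsilon \varepsilon, the standard folding (applied per output channel, with the channel dimension transposed for transposed convolutions, and b = 0 when the conv has no bias) is:

$$
W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}},
\qquad
b' = (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}} + \beta
$$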
26 changes: 17 additions & 9 deletions backends/xnnpack/operators/node_visitor.py
@@ -34,20 +34,23 @@
check_or_raise,
get_input_node,
get_param_tensor,
get_tensor_name,
is_param_node,
PERM_NCHW_TO_NHWC,
)

from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
from executorch.backends.xnnpack.utils.xnnpack_constants import (
UINT64_MAX,
XNN_INVALID_VALUE_ID,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from torch.export import ExportedProgram

XNN_TYPE_MAP = {
torch.float32: XNNDatatype.xnn_datatype_fp32,
}

from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
_aligned_size,
_pad_to,
CONSTANT_TENSOR_ALIGNMENT,
)

@@ -86,11 +89,11 @@ def __init__(
self,
exported_program: ExportedProgram,
external_ids: Dict,
constant_data_bytes: bytearray,
named_data_store: NamedDataStore,
) -> None:
self._external_ids = external_ids or {}
self._exported_program = exported_program or None
self._constant_data_bytes = constant_data_bytes
self._named_data_store = named_data_store

@property
def external_ids(self) -> Dict:
@@ -579,11 +582,16 @@ def get_serialized_buffer_index(
ctypes.POINTER(array_type),
).contents

offset = len(self._constant_data_bytes)
named_key = get_tensor_name(self.exported_program, get_attr_node)
if named_key == "":
raise ValueError(f"Tensor from node: {get_attr_node} has no name")

size = const_val.untyped_storage().nbytes()
xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
self._constant_data_bytes.extend(
_pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
xnn_graph.constant_data.append(
ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
)
self._named_data_store.add_named_data(
named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
)

return buffer_idx
72 changes: 60 additions & 12 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -11,7 +11,9 @@
#include <executorch/backends/xnnpack/serialization/schema_generated.h>
#include <executorch/extension/threadpool/threadpool.h>
#include <executorch/runtime/executor/pte_data_map.h>
#include <string>
#include <unordered_map>
#include <vector>

#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wglobal-constructors"
@@ -167,7 +169,8 @@ const uint8_t* getConstantDataPtr(
GraphPtr flatbuffer_graph,
const uint8_t* constant_data_ptr,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
auto buffer_idx = tensor_value->constant_buffer_idx();
if (buffer_idx) {
if (!constant_data_ptr) {
@@ -187,6 +190,15 @@
return constant_data_ptr + offset;
} else {
const std::string& data_name = constant_data_offset->named_key()->str();
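// With the weights cache enabled, unpacked constant data is loaded and owned
// by the XNNWeightsCache; otherwise it is read from the NamedDataMap and kept
// alive in freeable_buffers until the runtime has been created.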
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
Result<const uint8_t*> data_ptr =
weights_cache->load_unpacked_data(data_name);
if (!data_ptr.ok()) {
ET_LOG(Error, "Failed to load weights from cache");
return nullptr;
}
return data_ptr.get();
#else
Result<FreeableBuffer> buffer =
named_data_map->get_data(data_name.c_str());
if (!buffer.ok()) {
@@ -198,8 +210,9 @@
}
const uint8_t* data_ptr =
static_cast<const uint8_t*>(buffer.get().data());
loaded_buffers_from_map.push_back(std::move(buffer.get()));
freeable_buffers.push_back(std::move(buffer.get()));
return data_ptr;
#endif
}
}
}
@@ -222,7 +235,8 @@ Error defineTensor(
std::vector<uint32_t>& output_ids,
CompileAllocator& allocator,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& loaded_buffers_from_map) {
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;

@@ -264,7 +278,8 @@ Error defineTensor(
flatbuffer_graph,
constant_data_ptr,
named_data_map,
loaded_buffers_from_map);
freeable_buffers,
weights_cache);

xnn_status status;
// The type we might have to convert to
@@ -1999,9 +2014,9 @@ ET_NODISCARD Error XNNCompiler::compileModel(
const void* buffer_pointer,
size_t num_bytes,
XNNExecutor* executor,
MemoryAllocator* runtime_allocator,
const NamedDataMap* named_data_map,
xnn_workspace_t workspace) {
XNNWeightsCache* weights_cache,
xnn_workspace_t workspace,
const NamedDataMap* named_data_map) {
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
const uint8_t* flatbuffer_data = nullptr;
const uint8_t* constant_data = nullptr;
@@ -2065,11 +2080,14 @@ ET_NODISCARD Error XNNCompiler::compileModel(
// Invalid ids do not need to be remapped
remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);

// If the weights cache is not enabled, we hold onto all the unpacked buffers
// and free them at the end

// External Ids for inputs and outputs
std::vector<uint32_t> input_ids;
std::vector<uint32_t> output_ids;
Error err = Error::Ok;
std::vector<FreeableBuffer> loaded_buffers_from_map;
for (auto value : *flatbuffer_graph->xvalues()) {
err = defineTensor(
subgraph.get(),
Expand All @@ -2081,7 +2099,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
output_ids,
compile_allocator,
named_data_map,
loaded_buffers_from_map);
unpacked_buffers,
weights_cache);

if (err != Error::Ok) {
return err;
@@ -2103,20 +2122,34 @@

xnn_runtime_t runtime_ptr = nullptr;

// If the weights cache is not enabled, the XNNWeightsCache just manages the
// unpacked weights until the runtime is created.
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
ET_CHECK_OR_RETURN_ERROR(
unpacked_buffers.size() == 0,
Internal,
"Weight Cache is enabled, which means unpacked buffers should be owned by the cache");
xnn_weights_cache_t weights_cache_ptr =
weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get()
: nullptr;
#else
xnn_weights_cache_t weights_cache_ptr = nullptr;
#endif

#ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
ET_CHECK_OR_RETURN_ERROR(
workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace");
status = xnn_create_runtime_v4(
subgraph.get(),
/*weight_cache=*/nullptr, // TODO - support weight cache
weights_cache_ptr,
workspace,
::executorch::extension::threadpool::get_pthreadpool(),
runtime_flags,
&runtime_ptr);
#else
status = xnn_create_runtime_v3(
subgraph.get(),
/*weight_cache=*/nullptr, // TODO - support weight cache
weights_cache_ptr,
::executorch::extension::threadpool::get_pthreadpool(),
runtime_flags,
&runtime_ptr);
@@ -2128,10 +2161,25 @@
"XNN Runtime creation failed with code: %s",
xnn_status_to_string(status));

#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
auto packed_weights_names = weights_cache->finalize_for_runtime();
ET_CHECK_OR_RETURN_ERROR(
packed_weights_names.ok(),
Internal,
"Failed to finalize weights cache after creating the xnn runtime")
#else
for (auto& buffer : unpacked_buffers) {
buffer.Free();
}
Result<std::vector<std::string>> packed_weights_names =
std::vector<std::string>();
#endif

err = executor->initialize( // NOLINT: runtime_ptr is non-null
runtime_ptr,
std::move(input_ids),
std::move(output_ids));
std::move(output_ids),
std::move(packed_weights_names.get()));

return err;
};
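For context, the XNNWeightsCache interface implied by the call sites in this file is roughly the sketch below. This is a reconstruction from usage only: the method names and return types come from the calls above, while the header paths, namespace, and everything else are assumptions, not the actual class added by this PR.

#include <xnnpack.h> // xnn_weights_cache_t
#include <string>
#include <vector>
// Assumed ExecuTorch header for Result; the real path may differ.
#include <executorch/runtime/core/result.h>

using executorch::runtime::Result;

// Hypothetical interface sketch; see backends/xnnpack/runtime for the real one.
class XNNWeightsCache {
 public:
  // Return unpacked constant data for the tensor with this name, loading it
  // from the NamedDataMap on first use (used by getConstantDataPtr above).
  Result<const uint8_t*> load_unpacked_data(const std::string& name);

  // Number of unpacked buffers currently held by the cache.
  size_t get_num_unpacked_data();

  // Underlying xnn_weights_cache_t handle passed to xnn_create_runtime_v3/v4.
  xnn_weights_cache_t get();

  // After the runtime is created (weights are packed), drop the unpacked
  // buffers and return the names of packed weights owned by the cache.
  Result<std::vector<std::string>> finalize_for_runtime();
};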