Skip to content

Commit df2ba85

Browse files
authored
[Snippets][CPU] Refactor emitter factory (#33823)
### Tickets: - N/A
1 parent cb4469d commit df2ba85

File tree

8 files changed

+482
-321
lines changed

8 files changed

+482
-321
lines changed

src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp

Lines changed: 116 additions & 118 deletions
Large diffs are not rendered by default.

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <unordered_set>
1616
#include <vector>
1717

18+
#include "cache/multi_cache.h"
1819
#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
1920
#include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
2021
#include "emitters/snippets/aarch64/utils.hpp"
@@ -34,10 +35,12 @@ using jit_generator = dnnl::impl::cpu::aarch64::jit_generator;
3435
using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
3536
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;
3637

37-
jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(jit_generator* h,
38-
cpu_isa_t isa,
39-
const ExpressionPtr& expr,
40-
const snippets::KernelExecutorTablePtr& kernel_table)
38+
jit_gemm_copy_b_emitter::jit_gemm_copy_b_emitter(
39+
jit_generator* h,
40+
cpu_isa_t isa,
41+
const ExpressionPtr& expr,
42+
const snippets::KernelExecutorTablePtr& kernel_table,
43+
[[maybe_unused]] const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
4144
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
4245
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
4346
const auto gemm_repack = ov::as_type_ptr<GemmCopyB>(expr->get_node());

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_copy_b_emitter.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#pragma once
66

7+
#include "cache/multi_cache.h"
78
#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
89
#include "emitters/snippets/aarch64/kernel_executors/gemm_copy_b.hpp"
910
#include "snippets/emitter.hpp"
@@ -14,7 +15,8 @@ class jit_gemm_copy_b_emitter : public jit_binary_call_emitter {
1415
jit_gemm_copy_b_emitter(dnnl::impl::cpu::aarch64::jit_generator* h,
1516
dnnl::impl::cpu::aarch64::cpu_isa_t isa,
1617
const ov::snippets::lowered::ExpressionPtr& expr,
17-
const snippets::KernelExecutorTablePtr& kernel_table);
18+
const snippets::KernelExecutorTablePtr& kernel_table,
19+
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
1820

1921
size_t get_inputs_count() const override {
2022
return 1;

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <unordered_set>
1414
#include <vector>
1515

16+
#include "cache/multi_cache.h"
1617
#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
1718
#include "emitters/snippets/aarch64/kernel_executors/gemm.hpp"
1819
#include "emitters/snippets/aarch64/utils.hpp"
@@ -38,7 +39,8 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;
3839
jit_gemm_emitter::jit_gemm_emitter(jit_generator* h,
3940
cpu_isa_t isa,
4041
const ExpressionPtr& expr,
41-
const snippets::KernelExecutorTablePtr& kernel_table)
42+
const snippets::KernelExecutorTablePtr& kernel_table,
43+
[[maybe_unused]] const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
4244
: jit_binary_call_emitter(h, isa, expr->get_live_regs()) {
4345
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
4446
GemmKernelKaiConfig kernel_config;

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_gemm_emitter.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#pragma once
66

7+
#include "cache/multi_cache.h"
78
#include "emitters/snippets/aarch64/jit_binary_call_emitter.hpp"
89
#include "emitters/snippets/aarch64/kernel_executors/gemm.hpp"
910
#include "emitters/snippets/brgemm_generic.hpp"
@@ -16,7 +17,8 @@ class jit_gemm_emitter : public jit_binary_call_emitter {
1617
jit_gemm_emitter(dnnl::impl::cpu::aarch64::jit_generator* h,
1718
dnnl::impl::cpu::aarch64::cpu_isa_t isa,
1819
const ov::snippets::lowered::ExpressionPtr& expr,
19-
const snippets::KernelExecutorTablePtr& kernel_table);
20+
const snippets::KernelExecutorTablePtr& kernel_table,
21+
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
2022

2123
size_t get_inputs_count() const override {
2224
return 2;
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <memory>
8+
#include <set>
9+
#include <utility>
10+
#include <variant>
11+
#include <vector>
12+
13+
#include "cache/multi_cache.h"
14+
#include "openvino/core/node.hpp"
15+
#include "openvino/core/type/element_type.hpp"
16+
#include "snippets/lowered/expression.hpp"
17+
#include "snippets/target_machine.hpp"
18+
19+
namespace ov::intel_cpu {
20+
21+
/**
22+
* @brief Factory class for creating emitter instances for snippets code generation.
23+
* This template class provides a flexible way to create emitters with different ISA targets
24+
* and customizable wrapping behavior. It supports both snippets-specific emitters and CPU plugin emitters.
25+
* @tparam GetHost Callable type that returns the host generator instance
26+
* @tparam Isa ISA type (instruction set architecture) for target platform
27+
* @tparam Wrap Callable type for wrapping emitter instances with additional functionality
28+
* @tparam GetKernelExecutorTable Callable type that returns the kernel executor table (for caching emitters)
29+
*/
30+
template <typename GetHost, typename Isa, typename Wrap, typename GetKernelExecutorTable = void>
31+
class EmitterFactory {
32+
public:
33+
/**
34+
* @brief Constructs an EmitterFactory with the specified host getter, ISA, and wrapper.
35+
* @param get_host Callable that provides access to the host code generator
36+
* @param isa The target instruction set architecture
37+
* @param wrap Callable to wrap created emitter instances (e.g., for logging or instrumentation)
38+
*/
39+
EmitterFactory(GetHost get_host, Isa isa, Wrap wrap)
40+
: get_host_(std::move(get_host)),
41+
isa_(isa),
42+
wrap_(std::move(wrap)),
43+
get_kernel_executor_table_{},
44+
compiled_kernel_cache_{} {}
45+
46+
/**
47+
* @brief Constructs an EmitterFactory with caching support.
48+
* @param get_host Callable that provides access to the host code generator
49+
* @param isa The target instruction set architecture
50+
* @param wrap Callable to wrap created emitter instances
51+
* @param get_kernel_executor_table Callable that returns the current kernel executor table
52+
* @param compiled_kernel_cache Weak pointer to the compiled kernel cache
53+
*/
54+
template <typename T = GetKernelExecutorTable, typename = std::enable_if_t<!std::is_void_v<T>>>
55+
EmitterFactory(GetHost get_host,
56+
Isa isa,
57+
Wrap wrap,
58+
T get_kernel_executor_table,
59+
MultiCacheWeakPtr compiled_kernel_cache)
60+
: get_host_(std::move(get_host)),
61+
isa_(isa),
62+
wrap_(std::move(wrap)),
63+
get_kernel_executor_table_(std::move(get_kernel_executor_table)),
64+
compiled_kernel_cache_(std::move(compiled_kernel_cache)) {}
65+
66+
/**
67+
* @brief Creates a jitters_value for emitters that are constructed from a snippets expression.
68+
* This is the simple form for emitters that don't require additional constructor arguments
69+
* beyond get_host(), isa, and expr.
70+
* @tparam Emitter The emitter class type to instantiate
71+
* @return A jitters_value containing factory and precision query functions
72+
*/
73+
template <typename Emitter>
74+
[[nodiscard]] ov::snippets::jitters_value from_expr() const {
75+
return {[get_host = get_host_, isa = isa_, wrap = wrap_](
76+
const ov::snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<ov::snippets::Emitter> {
77+
auto emitter = std::make_shared<Emitter>(get_host(), isa, expr);
78+
return wrap(emitter, expr);
79+
},
80+
[](const std::shared_ptr<ov::Node>& n) -> std::set<ov::element::TypeVector> {
81+
return Emitter::get_supported_precisions(n);
82+
}};
83+
}
84+
85+
/**
86+
* @brief Creates a jitters_value for emitters that require kernel executor table and cache.
87+
* This method is used for emitters that support runtime recompilation and caching
88+
* (e.g., BRGEMM emitters). It uses the kernel_executor_table and compiled_kernel_cache
89+
* members that were provided during factory construction.
90+
* @tparam Emitter The emitter class type to instantiate
91+
* @return A jitters_value containing factory and precision query functions
92+
*/
93+
template <typename Emitter, typename T = GetKernelExecutorTable, typename = std::enable_if_t<!std::is_void_v<T>>>
94+
[[nodiscard]] ov::snippets::jitters_value from_expr_cached() const {
95+
OPENVINO_ASSERT(!compiled_kernel_cache_.expired(), "compiled_kernel_cache_ is expired in from_expr_cached");
96+
return {[get_host = get_host_,
97+
isa = isa_,
98+
wrap = wrap_,
99+
get_kernel_table = get_kernel_executor_table_,
100+
compiled_kernel_cache = compiled_kernel_cache_](
101+
const ov::snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<ov::snippets::Emitter> {
102+
auto emitter =
103+
std::make_shared<Emitter>(get_host(), isa, expr, get_kernel_table(), compiled_kernel_cache);
104+
return wrap(emitter, expr);
105+
},
106+
[](const std::shared_ptr<ov::Node>& n) -> std::set<ov::element::TypeVector> {
107+
return Emitter::get_supported_precisions(n);
108+
}};
109+
}
110+
111+
/**
112+
* @brief Creates a jitters_value for emitters that are constructed from an OpenVINO node.
113+
* This method generates a factory function for emitters that only require the host, ISA,
114+
* and the node from the expression. Unlike from_expr(), this does not apply the wrap_ callable.
115+
* @tparam Emitter The emitter class type to instantiate
116+
* @return A jitters_value containing factory and precision query functions
117+
*/
118+
template <typename Emitter>
119+
[[nodiscard]] ov::snippets::jitters_value from_node() const {
120+
return {[get_host = get_host_, isa = isa_](
121+
const ov::snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<ov::snippets::Emitter> {
122+
return std::make_shared<Emitter>(get_host(), isa, expr->get_node());
123+
},
124+
[](const std::shared_ptr<ov::Node>& n) -> std::set<ov::element::TypeVector> {
125+
return Emitter::get_supported_precisions(n);
126+
}};
127+
}
128+
129+
/**
130+
* @brief Creates a jitters_value for operations that are decomposed into low-level expressions.
131+
* This method returns a factory that produces nullptr (indicating no direct emitter implementation)
132+
* while still advertising the supported precisions for the operation. This is used for
133+
* high-level operations that are decomposed into lower-level expressions during compilation,
134+
* where the actual code generation happens at the decomposed expression level.
135+
* @param supported_precisions The set of precision combinations supported by this operation
136+
* @return A jitters_value with a null emitter factory and the provided precision information
137+
*/
138+
[[nodiscard]] static ov::snippets::jitters_value undefined(std::set<ov::element::TypeVector> supported_precisions) {
139+
return {[](const ov::snippets::lowered::ExpressionPtr&) -> std::shared_ptr<ov::snippets::Emitter> {
140+
return nullptr;
141+
},
142+
[supported_precisions = std::move(supported_precisions)](
143+
const std::shared_ptr<ov::Node>&) -> std::set<ov::element::TypeVector> {
144+
return supported_precisions;
145+
}};
146+
}
147+
148+
private:
149+
GetHost get_host_;
150+
Isa isa_;
151+
Wrap wrap_;
152+
std::conditional_t<std::is_void_v<GetKernelExecutorTable>, std::monostate, GetKernelExecutorTable>
153+
get_kernel_executor_table_;
154+
std::conditional_t<std::is_void_v<GetKernelExecutorTable>, std::monostate, MultiCacheWeakPtr>
155+
compiled_kernel_cache_;
156+
};
157+
158+
template <typename GetHost, typename Isa, typename Wrap>
159+
EmitterFactory(GetHost, Isa, Wrap) -> EmitterFactory<GetHost, Isa, Wrap, void>;
160+
161+
template <typename GetHost, typename Isa, typename Wrap, typename GetKernelExecutorTable>
162+
EmitterFactory(GetHost, Isa, Wrap, GetKernelExecutorTable, MultiCacheWeakPtr)
163+
-> EmitterFactory<GetHost, Isa, Wrap, GetKernelExecutorTable>;
164+
165+
} // namespace ov::intel_cpu

0 commit comments

Comments
 (0)