Skip to content

[GPU] Enable custom op with dynamic shape #30880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
adeb32b
draft: enable gpu to support dynamic customer op.
xipingyan Jun 5, 2025
bf563f3
wrapper calc work size for dynamic shape to update gws.
xipingyan Jun 5, 2025
133c3f8
Clone a new op to make sure original model can be released.Back
xipingyan Jun 11, 2025
129baa5
update debug log, and revert useless update.
xipingyan Jun 12, 2025
9333900
is_dynamic->is_dynamic_input
xipingyan Jun 12, 2025
d13c484
wrapper get_output_shape
xipingyan Jun 17, 2025
977877a
Move update gws,lws to primitive_imple create.
xipingyan Jun 19, 2025
026a4ee
Fix gpu unit test fail issue.
xipingyan Jun 24, 2025
1a1bf9f
Add test case for dynamic shape custom op.
xipingyan Jul 3, 2025
dcf0529
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Jul 3, 2025
aed6f42
Fix test case build fail issue.
xipingyan Jul 3, 2025
59f6064
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Jul 3, 2025
aa974d0
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Jul 4, 2025
5a8363c
fix windows build issue.
xipingyan Jul 4, 2025
7216436
fix windows build issue.
xipingyan Jul 4, 2025
afebea2
fix unit test fail: custom_gpu_primitive_f32.add_basic_in2x2x2x2
xipingyan Jul 5, 2025
9bee3e5
Regist custom_gpu_primitive with dynamic_shape kernel.
xipingyan Jul 27, 2025
a3c26c1
1: test kernel: get index based on macro
xipingyan Jul 29, 2025
0f64fd8
Override get_shape_infer_dependencies
xipingyan Jul 29, 2025
5512c1c
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Jul 30, 2025
5d0b8aa
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Jul 30, 2025
dc6f17f
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Jul 31, 2025
6e3dd7f
Fix ci issue.
xipingyan Aug 1, 2025
7554290
Merge branch 'xp/enable_custom_op_with_dynamic_shape' of https://gith…
xipingyan Aug 1, 2025
116f797
Merge branch 'master' into xp/enable_custom_op_with_dynamic_shape
xipingyan Aug 1, 2025
c1473b8
move generateTestFilePrefix to setup and teardown.
xipingyan Aug 1, 2025
422c3ab
Add test: custom op static model accuracy test.
xipingyan Aug 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -136,6 +136,9 @@ struct kernel_impl_params final {
return output_layouts[idx];
}

size_t get_input_layout_size() const {
return input_layouts.size();
}

bool has_fused_primitives() const { return !fused_desc.empty(); }

Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
#pragma once
#include "primitive.hpp"
#include "intel_gpu/runtime/memory.hpp"
#include "intel_gpu/plugin/simple_math.hpp"
#include <vector>
#include <string>

@@ -47,6 +48,51 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
}
};

static void update_work_group_size(const ov::PartialShape& dims,
int calcWgDimInputIdx,
const ov::PartialShape& inputDims,
const std::vector<std::string>& globalSizeRules,
const std::vector<std::string>& localSizeRules,
std::vector<size_t>& gws,
std::vector<size_t>& lws) {
#define GetDim(DIM) DIM.is_dynamic() ? -1 : DIM.get_length()

gws.clear();
lws.clear();

int batchDim = 0, featureDim = 0, yDim = 0, xDim = 0;
// if calcWgDimInputIdx is greater than -1, take dimension from input
if (calcWgDimInputIdx >= 0) {
xDim = static_cast<int>(GetDim(inputDims[inputDims.size() - 1]));
yDim = dims.size() > 1 ? static_cast<int>(GetDim(inputDims[inputDims.size() - 2])) : 0;
featureDim = dims.size() > 2 ? static_cast<int>(GetDim(inputDims[inputDims.size() - 3])) : 0;
batchDim = dims.size() > 3 ? static_cast<int>(GetDim(inputDims[inputDims.size() - 4])) : 0;
} else {
batchDim = (dims.size() > 0) ? GetDim(dims[0]) : 1;
featureDim = (dims.size() > 1) ? GetDim(dims[1]) : 1;
yDim = (dims.size() > 2) ? GetDim(dims[2]) : 1;
xDim = (dims.size() > 3) ? GetDim(dims[3]) : 1;
}
const std::map<char, int> vars = {
{'b', batchDim}, {'B', batchDim},
{'f', featureDim}, {'F', featureDim},
{'y', yDim}, {'Y', yDim},
{'x', xDim}, {'X', xDim},
};
for (const auto& rule : globalSizeRules) {
SimpleMathExpression expr;
expr.SetVariables(vars);
expr.SetExpression(rule);
gws.push_back(expr.Evaluate());
}
for (const auto& rule : localSizeRules) {
SimpleMathExpression expr;
expr.SetVariables(vars);
expr.SetExpression(rule);
lws.push_back(expr.Evaluate());
}
}

/// @brief Constructs custom_gpu_primitive primitive
/// @param id This primitive id.
/// @param input Input primitive ids.
@@ -65,15 +111,23 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
const std::string& build_options,
const layout& output_layout,
const std::vector<size_t>& gws = {},
const std::vector<size_t>& lws = {})
const std::vector<size_t>& lws = {},
const std::shared_ptr<ov::Node>& op = nullptr,
const int calcWgDimInputIdx = -1,
const std::vector<std::string> globalSizeRules = {},
const std::vector<std::string> localSizeRules = {})
: primitive_base(id, inputs, 1, {optional_data_type()}, {output_layout.data_padding}),
kernel_entry_point(kernel_entry_point),
kernel_arguments(kernel_arguments),
build_options(build_options),
output_layout(output_layout),
gws(gws.size() ? gws : std::vector<size_t>{output_layout.count()}),
lws(lws),
kernels_code(kernels_code) {}
kernels_code(kernels_code),
op(op),
calcWgDimInputIdx(calcWgDimInputIdx),
globalSizeRules(globalSizeRules),
localSizeRules(localSizeRules) {}

/// @brief The name of the entry point function in the kernel
const std::string kernel_entry_point;
@@ -89,6 +143,13 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
const std::vector<size_t> lws;
/// @brief Source code for the kernel
const primitive_id_arr kernels_code;
/// @brief Original IR op
const std::shared_ptr<ov::Node> op;
/// @brief -1: mean calc gws via output, else calc gws via inputs
const int calcWgDimInputIdx = -1;
/// @brief Custom provided rules for calc work sizes.
const std::vector<std::string> globalSizeRules;
const std::vector<std::string> localSizeRules;

size_t hash() const override {
size_t seed = primitive::hash();
48 changes: 41 additions & 7 deletions src/plugins/intel_gpu/src/graph/impls/ocl/custom_primitive.cpp
Original file line number Diff line number Diff line change
@@ -71,6 +71,17 @@ struct custom_gpu_primitive_impl : typed_primitive_impl<custom_gpu_primitive> {
return {kernels_cache.get_cached_kernel_id(_kernels[0])};
}

void set_kernels(cldnn::kernels_cache::compiled_kernels kernels) override {
OPENVINO_ASSERT(kernels.size() == 1, "Only the kernels of the single primitive should be allowed.");
auto& kernel_vec = kernels.begin()->second;
_kernels.clear();
_kernels.resize(kernel_vec.size());
for (auto& k : kernel_vec) {
auto sub_kernel_idx = k.second;
_kernels[sub_kernel_idx] = k.first;
}
}

void set_arguments_impl(custom_gpu_primitive_inst& instance) override {
auto& stream = instance.get_network().get_stream();
kernel_arguments_data args;
@@ -211,14 +222,16 @@ static void add_layout_to_jit(kernel_selector::jit_constants& mem_consts, const
mem_consts.AddConstant(kernel_selector::MakeJitConstant(name + "_OFFSET", std::to_string(offset)));
}

static std::string get_jit_constant(const custom_gpu_primitive_node& outer, const kernel_impl_params& impl_param) {
static std::string get_jit_constant(const custom_gpu_primitive_node& outer,
const kernel_impl_params& impl_param,
const std::vector<size_t>& gws,
const std::vector<size_t>& lws) {
kernel_selector::jit_constants mem_consts{
kernel_selector::MakeJitConstant("NUM_INPUTS", std::to_string(outer.get_dependencies().size()))};
const auto primitive = outer.get_primitive().get();

mem_consts.AddConstants({
kernel_selector::MakeJitConstant("GLOBAL_WORKSIZE", primitive->gws),
kernel_selector::MakeJitConstant("LOCAL_WORKSIZE", primitive->lws),
kernel_selector::MakeJitConstant("GLOBAL_WORKSIZE", gws),
kernel_selector::MakeJitConstant("LOCAL_WORKSIZE", lws),
});

for (size_t i = 0; i < impl_param.input_layouts.size(); i++) {
@@ -239,17 +252,38 @@ static std::string get_jit_constant(const custom_gpu_primitive_node& outer, cons
static std::unique_ptr<primitive_impl> create(const custom_gpu_primitive_node& arg, const kernel_impl_params& impl_param) {
const auto primitive = arg.get_primitive().get();

const auto& orig_output_layout = impl_param.get_output_layout();
OPENVINO_ASSERT(orig_output_layout.is_static(), "out layouts should be static for create primitive_impl!");

std::vector<size_t> gws, lws;
custom_gpu_primitive::update_work_group_size(orig_output_layout.get_partial_shape(),
primitive->calcWgDimInputIdx,
orig_output_layout.get_partial_shape(),
primitive->globalSizeRules,
primitive->localSizeRules,
gws,
lws);

if (gws.empty()) {
gws = primitive->gws;
}
if (lws.empty()) {
lws = primitive->lws;
}

auto cl_kernel = std::make_shared<kernel_selector::cl_kernel_data>();
cl_kernel->code.kernelString = std::make_shared<kernel_selector::kernel_string>();
cl_kernel->code.kernelString->entry_point = primitive->kernel_entry_point;
cl_kernel->code.kernelString->options = primitive->build_options;
cl_kernel->code.kernelString->jit = get_jit_constant(arg, impl_param);
const std::vector<size_t> const_gws = gws;
const std::vector<size_t> const_lws = lws;
cl_kernel->code.kernelString->jit = get_jit_constant(arg, impl_param, const_gws, const_lws);
for (const auto& s : primitive->kernels_code) {
cl_kernel->code.kernelString->str += s + "\n";
}

cl_kernel->params.workGroups.global = primitive->gws;
cl_kernel->params.workGroups.local = primitive->lws;
cl_kernel->params.workGroups.global = gws;
cl_kernel->params.workGroups.local = lws;

for (const auto& p : primitive->kernel_arguments) {
cl_kernel->params.arguments.push_back(get_arg(p));
Original file line number Diff line number Diff line change
@@ -5,11 +5,22 @@
#pragma once
#include "intel_gpu/primitives/custom_gpu_primitive.hpp"
#include "primitive_inst.h"
#include "openvino/op/parameter.hpp"

#include <string>

namespace cldnn {

template <>
struct typed_program_node<custom_gpu_primitive> : public typed_program_node_base<custom_gpu_primitive> {
using parent = typed_program_node_base<custom_gpu_primitive>;

public:
typed_program_node(std::shared_ptr<primitive> prim, program& prog) : parent(prim, prog) {}
program_node& input() const { return get_dependency(0); }

std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
};
using custom_gpu_primitive_node = typed_program_node<custom_gpu_primitive>;

template <>
@@ -19,12 +30,14 @@ class typed_primitive_inst<custom_gpu_primitive> : public typed_primitive_inst_b

public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(custom_gpu_primitive_node const& /*node*/, const kernel_impl_params& impl_param) {
static std::vector<layout> calc_output_layouts(custom_gpu_primitive_node const& node, const kernel_impl_params& impl_param) {
assert(static_cast<bool>(impl_param.desc->output_data_types[0]) == false &&
"Output data type forcing is not supported for "
"custom_gpu_primitive_node!");
layout output_layout = impl_param.typed_desc<custom_gpu_primitive>()->output_layout;

typed_primitive_inst<custom_gpu_primitive>::update_output_shape(impl_param, output_layout);

// if the output layout format was set to any, it means the layer output format will be the same as the first input
if (output_layout.format == format::any) {
output_layout.format = impl_param.get_input_layout().format;
@@ -38,6 +51,8 @@ class typed_primitive_inst<custom_gpu_primitive> : public typed_primitive_inst_b
"custom_gpu_primitive_node!");
layout output_layout = impl_param.typed_desc<custom_gpu_primitive>()->output_layout;

typed_primitive_inst<custom_gpu_primitive>::update_output_shape(impl_param, output_layout);

// if the output layout format was set to any, it means the layer output format will be the same as the first
// input
if (output_layout.format == format::any) {
@@ -50,6 +65,34 @@ class typed_primitive_inst<custom_gpu_primitive> : public typed_primitive_inst_b

public:
typed_primitive_inst(network& network, custom_gpu_primitive_node const& node);

private:
static void update_output_shape(const kernel_impl_params& impl_param, layout& output_layout) {
bool is_dynamic_input = false;
const auto inp_sz = impl_param.get_input_layout_size();
for (size_t i = 0; i < inp_sz; i++) {
if (impl_param.get_input_layout(i).is_dynamic()) {
is_dynamic_input = true;
break;
}
}

// Execute the op's shape inference only for dynamic node when input shapes have already been calculated; otherwise, keep the original output layout
// unchanged (it will be either static for static model or have dynamic shape in case of dynamic flow)
if (!is_dynamic_input && output_layout.is_dynamic()) {
ov::OutputVector new_inputs;
for (size_t i = 0; i < inp_sz; i++) {
auto input = std::make_shared<ov::op::v0::Parameter>(impl_param.get_input_layout(i).data_type, impl_param.get_input_layout(i).get_shape());
new_inputs.emplace_back(input);
}

auto op = impl_param.typed_desc<custom_gpu_primitive>()->op;
auto new_op = op->clone_with_new_inputs(new_inputs);
new_op->validate_and_infer_types();
auto new_outp_shape = new_op->get_output_shape(0);
output_layout.set_partial_shape(new_outp_shape);
}
}
};

using custom_gpu_primitive_inst = typed_primitive_inst<custom_gpu_primitive>;
80 changes: 39 additions & 41 deletions src/plugins/intel_gpu/src/plugin/ops/custom.cpp
Original file line number Diff line number Diff line change
@@ -169,65 +169,63 @@ void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op, Cust
const std::string layerTitle("\n// Layer " + op->get_friendly_name() + " using Custom Layer " + customLayer->Name() + "\n");
const std::string defineTitle("// Custom Layer User Defines\n");

auto dims = op->get_output_shape(0);
size_t N = (dims.size() > 0) ? dims[0] : 1;
size_t C = (dims.size() > 1) ? dims[1] : 1;
size_t H = (dims.size() > 2) ? dims[2] : 1;
size_t W = (dims.size() > 3) ? dims[3] : 1;
cldnn::tensor outputTensor = cldnn::tensor(cldnn::batch(N), cldnn::feature(C), cldnn::spatial(W, H));
auto dims = op->get_output_partial_shape(0);
int iidx = customLayer->InputDimSourceIndex();

cldnn::layout outputLayout = cldnn::layout(cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat, outputTensor);
size_t N = (dims.size() > 0) ? dims[0].is_dynamic() ? -1 : dims[0].get_length() : 1;
size_t C = (dims.size() > 1) ? dims[1].is_dynamic() ? -1 : dims[1].get_length() : 1;
size_t H = (dims.size() > 2) ? dims[2].is_dynamic() ? -1 : dims[2].get_length() : 1;
size_t W = (dims.size() > 3) ? dims[3].is_dynamic() ? -1 : dims[3].get_length() : 1;

// evaluate work sizes rules
std::vector<size_t> gws, lws;
cldnn::layout outputLayout;
if (dims.is_dynamic()) {
outputLayout = cldnn::layout(dims, cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat);
} else {
cldnn::tensor outputTensor = cldnn::tensor(cldnn::batch(N), cldnn::feature(C), cldnn::spatial(W, H));
outputLayout = cldnn::layout(cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat, outputTensor);
}

// assume output tensor is dimension source by default
int batchDim = outputTensor.batch[0];
int featureDim = outputTensor.feature[0];
int yDim = outputTensor.spatial[1];
int xDim = outputTensor.spatial[0];
int iidx = customLayer->InputDimSourceIndex();
std::vector<size_t> gws, lws;

std::string genericLayerName = layer_type_name_ID(op);
// if input index is greater than -1, take dimension from input
if (iidx >= 0) {
if (static_cast<size_t>(iidx) >= op->get_input_size())
OPENVINO_THROW("Invalid input tensor for index: ", iidx);
auto inputDims = op->get_input_shape(iidx);
cldnn::custom_gpu_primitive::update_work_group_size(dims, iidx, inputDims, customLayer->GlobalSizeRules(), customLayer->LocalSizeRules(), gws, lws);
} else {
cldnn::custom_gpu_primitive::update_work_group_size(dims,
iidx,
ov::PartialShape(),
customLayer->GlobalSizeRules(),
customLayer->LocalSizeRules(),
gws,
lws);
}

std::string genericLayerName = layer_type_name_ID(op);

xDim = static_cast<int>(inputDims[inputDims.size() - 1]);
yDim = dims.size() > 1 ? static_cast<int>(inputDims[inputDims.size() - 2]) : 0;
featureDim = dims.size() > 2 ? static_cast<int>(inputDims[inputDims.size() - 3]) : 0;
batchDim = dims.size() > 3 ? static_cast<int>(inputDims[inputDims.size() - 4]) : 0;
}
const std::map<char, int> vars = {
{ 'b', batchDim } , { 'B', batchDim },
{ 'f', featureDim }, { 'F', featureDim },
{ 'y', yDim }, { 'Y', yDim },
{ 'x', xDim }, { 'X', xDim },
};
for (const auto& rule : customLayer->GlobalSizeRules()) {
SimpleMathExpression expr;
expr.SetVariables(vars);
expr.SetExpression(rule);
gws.push_back(expr.Evaluate());
}
for (const auto& rule : customLayer->LocalSizeRules()) {
SimpleMathExpression expr;
expr.SetVariables(vars);
expr.SetExpression(rule);
lws.push_back(expr.Evaluate());
// Clone a new op to make sure original model can be released.
ov::OutputVector new_inputs;
for (size_t i = 0; i < op->get_input_size(); i++) {
auto input = std::make_shared<ov::op::v0::Parameter>(op->get_input_element_type(i), op->get_input_partial_shape(i));
new_inputs.emplace_back(input);
}
std::shared_ptr<ov::Node> op_bk = op->clone_with_new_inputs(new_inputs);

auto customPrim = cldnn::custom_gpu_primitive(genericLayerName,
reordered_inputs,
{ layerTitle, defineTitle, layerDefines, customLayer->KernelSource() },
{layerTitle, defineTitle, layerDefines, customLayer->KernelSource()},
customLayer->KernelEntry(),
kernelParameters,
customLayer->CompilerOptions(),
outputLayout,
gws,
lws);
lws,
op_bk,
iidx,
customLayer->GlobalSizeRules(),
customLayer->LocalSizeRules());
p.add_primitive(*op, customPrim);

auto prevLayerName = genericLayerName;
@@ -236,7 +234,7 @@ void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op, Cust
auto reorderPrimName = genericLayerName + ProgramBuilder::m_postCustomLayerTag;
p.add_primitive(*op, cldnn::reorder(reorderPrimName,
cldnn::input_info(genericLayerName),
cldnn::format::get_default_format(op->get_output_shape(0).size()),
cldnn::format::get_default_format(op->get_output_partial_shape(0).size()),
customPrim.output_layout.data_type));
prevLayerName = reorderPrimName;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <string>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/runtime/core.hpp"
#include "openvino/runtime/exec_model_info.hpp"
#include "openvino/runtime/properties.hpp"
#include "shared_test_classes/base/ov_behavior_test_utils.hpp"

using namespace ::testing;

namespace ov {
namespace test {
namespace intel_gpu {

class CustomAddOp : public ov::op::Op {
private:
float m_alpha;
float m_beta;

public:
OPENVINO_OP("CustomAddOp", "gpu_opset");

CustomAddOp() = default;

CustomAddOp(const ov::Output<ov::Node>& input, float alpha, float beta) : Op({input}), m_alpha(alpha), m_beta(beta) {
constructor_validate_and_infer_types();
}

void validate_and_infer_types() override {
set_output_size(1);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

bool visit_attributes(ov::AttributeVisitor& visitor) override {
visitor.on_attribute("alpha", m_alpha);
visitor.on_attribute("beta", m_beta);
return true;
}

std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
OPENVINO_ASSERT(new_args.size() == 1, "Incorrect number of new arguments");
return std::make_shared<CustomAddOp>(new_args[0], m_alpha, m_beta);
}

bool has_evaluate() const override {
return true;
}

bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override {
auto in = inputs[0];
auto out = outputs[0];
out.set_shape(in.get_shape());
for (size_t i = 0; i < out.get_size(); i++) {
out.data<float>()[i] = in.data<float>()[i] * m_alpha + m_beta;
}
return true;
}
};

using CustomOpDynamicTestParams = std::tuple<std::vector<ov::Shape>, // input shape
std::vector<std::vector<float>>>; // input data
class CustomOpDynamic : public ov::test::TestsCommon, public testing::WithParamInterface<CustomOpDynamicTestParams> {
void SetUp() override {
generate_config_files();
};

void TearDown() override {
ov::test::utils::removeFile(config_cl);
ov::test::utils::removeFile(config_xml);
}

public:
static std::string getTestCaseName(const testing::TestParamInfo<CustomOpDynamicTestParams>& obj) {
std::vector<ov::Shape> input_shapes;
std::vector<std::vector<float>> input_datas;
std::tie(input_shapes, input_datas) = obj.param;

std::ostringstream result;
result << "input_shape=";
for (auto shape : input_shapes) {
result << shape;
}
return result.str();
}

static const size_t dim1 = 1;
void run() {
std::vector<ov::Shape> input_shapes;
std::vector<std::vector<float>> input_datas;
std::tie(input_shapes, input_datas) = GetParam();
ASSERT_TRUE(input_shapes.size() == input_datas.size());

ov::Core core;
float alpha = 1.0, beta = 0.1;
auto model = generate_model_with_custom_add_op(alpha, beta, ov::PartialShape{-1, dim1, -1});

ov::AnyMap config = {ov::hint::inference_precision(ov::element::f32), {"CONFIG_FILE", config_xml}};
auto compiled_model = core.compile_model(model, ov::test::utils::DEVICE_GPU, config);

auto runtime_graph = compiled_model.get_runtime_model();
auto ops = runtime_graph->get_ordered_ops();

bool found_custom_op = false;
for (auto op : ops) {
if (op->get_rt_info()[ov::exec_model_info::LAYER_TYPE].as<std::string>() == "CustomGPUPrimitive") {
found_custom_op = true;
break;
}
}
ASSERT_TRUE(found_custom_op);

auto ireq = compiled_model.create_infer_request();
for (size_t i = 0; i < input_datas.size(); i++) {
auto input = ov::Tensor({ov::element::f32}, input_shapes[i], input_datas[i].data());
ireq.set_input_tensor(0, input);
ireq.infer();
auto output = ireq.get_output_tensor(0);
std::vector<float> actual(output.data<float>(), output.data<float>() + output.get_size());

ASSERT_EQ(output.get_element_type(), element::f32);

float* inp_data = input.data<float>();
for (size_t i = 0; i < output.get_size(); i++) {
ASSERT_FLOAT_EQ(actual[i], inp_data[i] * alpha + beta);
}
}
}

protected:
std::string config_cl;
std::string config_xml;

void generate_config_files() {
config_cl = ov::test::utils::generateTestFilePrefix() + "_custom_op_dynamic.cl";
config_xml = ov::test::utils::generateTestFilePrefix() + "_custom_op_dynamic.xml";

std::string content_cl = R"(
__kernel void custom_add_kernel(
__global const INPUT0_TYPE* inp0,
__global OUTPUT0_TYPE* outp) {
const uint b = (uint)get_global_id(0);
const uint f = (uint)get_global_id(1);
const uint y = (uint)get_global_id(2);
#if INPUT0_DIMS_SIZE == 4
const uint x = 0;
#endif
const unsigned src_index = b*INPUT0_DIMS[1]*INPUT0_DIMS[2]*INPUT0_DIMS[3] + f*INPUT0_DIMS[2]*INPUT0_DIMS[3] + y*INPUT0_DIMS[3] + x;
const unsigned dst_index = src_index;
outp[dst_index] = inp0[src_index] * alpha + beta;
})";

std::string content_xml = R"(
<CustomLayer name="CustomAddOp" type="SimpleGPU" version="1">
<Kernel entry="custom_add_kernel">
<Source filename=")" + config_cl + R"("/>
<Define name="alpha" type="float" param="alpha" default="1.0"/>
<Define name="beta" type="float" param="beta" default="0.1"/>
</Kernel>
<Buffers>
<Tensor arg-index="0" type="input" port-index="0" format="BFYX"/>
<Tensor arg-index="1" type="output" port-index="0" format="BFYX"/>
</Buffers>
<CompilerOptions options="-cl-mad-enable"/>
<WorkSizes global="B,F,Y"/>
</CustomLayer>)";

ov::test::utils::createFile(config_cl, content_cl);
ov::test::utils::createFile(config_xml, content_xml);
}

std::shared_ptr<ov::Model> generate_model_with_custom_add_op(float alpha, float beta, ov::PartialShape inp_shape) {
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, inp_shape);
auto op = std::make_shared<CustomAddOp>(input, alpha, beta);
auto result = std::make_shared<ov::op::v0::Result>(op);
return std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{input}, "model_with_custom_op_dynamic");
}
};

class CustomOpStatic : public CustomOpDynamic {
public:
void run() {
std::vector<ov::Shape> input_shapes;
std::vector<std::vector<float>> input_datas;
std::tie(input_shapes, input_datas) = GetParam();
ASSERT_EQ(input_shapes.size(), input_datas.size());
ASSERT_EQ(input_shapes.size(), 1u);

ov::Core core;
float alpha = 1.0, beta = 0.1;
auto model = generate_model_with_custom_add_op(alpha, beta, ov::PartialShape(input_shapes[0]));

ov::AnyMap config = {ov::hint::inference_precision(ov::element::f32), {"CONFIG_FILE", config_xml}};
auto compiled_model = core.compile_model(model, ov::test::utils::DEVICE_GPU, config);

auto runtime_graph = compiled_model.get_runtime_model();
auto ops = runtime_graph->get_ordered_ops();

bool found_custom_op = false;
for (auto op : ops) {
if (op->get_rt_info()[ov::exec_model_info::LAYER_TYPE].as<std::string>() == "CustomGPUPrimitive") {
found_custom_op = true;
break;
}
}
ASSERT_TRUE(found_custom_op);

auto ireq = compiled_model.create_infer_request();
auto input = ov::Tensor({ov::element::f32}, input_shapes[0], input_datas[0].data());
ireq.set_input_tensor(0, input);
ireq.infer();
auto output = ireq.get_output_tensor(0);
std::vector<float> actual(output.data<float>(), output.data<float>() + output.get_size());

ASSERT_EQ(output.get_element_type(), element::f32);

float* inp_data = input.data<float>();
for (size_t i = 0; i < output.get_size(); i++) {
ASSERT_FLOAT_EQ(actual[i], inp_data[i] * alpha + beta);
}
}
};

TEST_P(CustomOpDynamic, Accuracy) {
run();
}

TEST_P(CustomOpStatic, Accuracy) {
run();
}

const std::vector<ov::Shape> input_shapes{{1, CustomOpDynamic::dim1, 2}, {2, CustomOpDynamic::dim1, 3}};
const std::vector<std::vector<float>> input_datas{{0.2, 0.4}, {0.2, 0.4, 0.3, 0.5, 0.7, 0.9}};

INSTANTIATE_TEST_SUITE_P(smoke_GPU_Accuracy,
CustomOpDynamic,
::testing::Combine(::testing::Values(input_shapes), ::testing::Values(input_datas)),
CustomOpDynamic::getTestCaseName);

const std::vector<ov::Shape> input_static_shapes{{2, 2, 3}};
const std::vector<std::vector<float>> input_static_datas{{0.2, 0.4, 0.3, 0.5, 0.7, 0.9, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6}};

INSTANTIATE_TEST_SUITE_P(smoke_GPU_Accuracy,
CustomOpStatic,
::testing::Combine(::testing::Values(input_static_shapes), ::testing::Values(input_static_datas)),
CustomOpStatic::getTestCaseName);

} // namespace intel_gpu
} // namespace test
} // namespace ov
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ file(GLOB_RECURSE SOURCES_MAIN
"${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/remote_tensor.cpp"
"${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/usm_host_tensor.cpp"
"${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/common_utils.cpp"
"${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/simple_math.cpp"
)

if (NOT ENABLE_ONEDNN_FOR_GPU)