diff --git a/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp b/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp
index bb4bc6bed2ecea..bd1798f6beb4e4 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp
@@ -136,6 +136,9 @@ struct kernel_impl_params final {
         return output_layouts[idx];
     }
 
+    size_t get_input_layout_size() const {
+        return input_layouts.size();
+    }
     bool has_fused_primitives() const { return !fused_desc.empty(); }
diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/custom_gpu_primitive.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/custom_gpu_primitive.hpp
index e086cfb13dadbe..97650edf0a2088 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/primitives/custom_gpu_primitive.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/custom_gpu_primitive.hpp
@@ -5,6 +5,7 @@
 #pragma once
 #include "primitive.hpp"
 #include "intel_gpu/runtime/memory.hpp"
+#include "intel_gpu/plugin/simple_math.hpp"
 
 #include <vector>
 #include <string>
@@ -47,6 +48,51 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
         }
     };
 
+    static void update_work_group_size(const ov::PartialShape& dims,
+                                       int calcWgDimInputIdx,
+                                       const ov::PartialShape& inputDims,
+                                       const std::vector<std::string>& globalSizeRules,
+                                       const std::vector<std::string>& localSizeRules,
+                                       std::vector<size_t>& gws,
+                                       std::vector<size_t>& lws) {
+#define GetDim(DIM) ((DIM).is_dynamic() ? -1 : (DIM).get_length())
+
+        gws.clear();
+        lws.clear();
+
+        int batchDim = 0, featureDim = 0, yDim = 0, xDim = 0;
+        // if calcWgDimInputIdx is non-negative, take the dimensions from that input
+        if (calcWgDimInputIdx >= 0) {
+            xDim = static_cast<int>(GetDim(inputDims[inputDims.size() - 1]));
+            yDim = dims.size() > 1 ? static_cast<int>(GetDim(inputDims[inputDims.size() - 2])) : 0;
+            featureDim = dims.size() > 2 ? static_cast<int>(GetDim(inputDims[inputDims.size() - 3])) : 0;
+            batchDim = dims.size() > 3 ? static_cast<int>(GetDim(inputDims[inputDims.size() - 4])) : 0;
+        } else {
+            batchDim = (dims.size() > 0) ? static_cast<int>(GetDim(dims[0])) : 1;
+            featureDim = (dims.size() > 1) ? static_cast<int>(GetDim(dims[1])) : 1;
+            yDim = (dims.size() > 2) ? static_cast<int>(GetDim(dims[2])) : 1;
+            xDim = (dims.size() > 3) ? static_cast<int>(GetDim(dims[3])) : 1;
+        }
+#undef GetDim
+
+        const std::map<char, int> vars = {
+            {'b', batchDim}, {'B', batchDim},
+            {'f', featureDim}, {'F', featureDim},
+            {'y', yDim}, {'Y', yDim},
+            {'x', xDim}, {'X', xDim},
+        };
+        for (const auto& rule : globalSizeRules) {
+            SimpleMathExpression expr;
+            expr.SetVariables(vars);
+            expr.SetExpression(rule);
+            gws.push_back(expr.Evaluate());
+        }
+        for (const auto& rule : localSizeRules) {
+            SimpleMathExpression expr;
+            expr.SetVariables(vars);
+            expr.SetExpression(rule);
+            lws.push_back(expr.Evaluate());
+        }
+    }
+
     /// @brief Constructs custom_gpu_primitive primitive
     /// @param id This primitive id.
     /// @param input Input primitive ids.
@@ -65,7 +111,11 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
                          const std::string& build_options,
                          const layout& output_layout,
                          const std::vector<size_t>& gws = {},
-                         const std::vector<size_t>& lws = {})
+                         const std::vector<size_t>& lws = {},
+                         const std::shared_ptr<ov::Node>& op = nullptr,
+                         const int calcWgDimInputIdx = -1,
+                         const std::vector<std::string> globalSizeRules = {},
+                         const std::vector<std::string> localSizeRules = {})
         : primitive_base(id, inputs, 1, {optional_data_type()}, {output_layout.data_padding}),
           kernel_entry_point(kernel_entry_point),
           kernel_arguments(kernel_arguments),
@@ -73,7 +123,11 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
           output_layout(output_layout),
           gws(gws.size() ? gws : std::vector<size_t>{output_layout.count()}),
           lws(lws),
-          kernels_code(kernels_code) {}
+          kernels_code(kernels_code),
+          op(op),
+          calcWgDimInputIdx(calcWgDimInputIdx),
+          globalSizeRules(globalSizeRules),
+          localSizeRules(localSizeRules) {}
 
     /// @brief The name of the entry point function in the kernel
     const std::string kernel_entry_point;
@@ -89,6 +143,13 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
     const std::vector<size_t> lws;
     /// @brief Source code for the kernel
     const primitive_id_arr kernels_code;
+    /// @brief Original IR op
+    const std::shared_ptr<ov::Node> op;
+    /// @brief If -1, work sizes are derived from the output shape; otherwise from the input with this index
+    const int calcWgDimInputIdx = -1;
+    /// @brief User-provided rules for computing the global/local work sizes
+    const std::vector<std::string> globalSizeRules;
+    const std::vector<std::string> localSizeRules;
 
     size_t hash() const override {
         size_t seed = primitive::hash();
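For context, each entry in `globalSizeRules`/`localSizeRules` is a small arithmetic expression over the dimension variables b/f/y/x (upper- or lower-case). A minimal sketch of how a single rule is evaluated, using the same `SimpleMathExpression` interface as above; the function name and values are illustrative:

```cpp
#include "intel_gpu/plugin/simple_math.hpp"

// Evaluates the rule "B*F" for a 2x16x4x4 tensor; the variable map mirrors
// the `vars` map built inside update_work_group_size().
size_t evaluate_example_rule() {
    SimpleMathExpression expr;
    expr.SetVariables({{'B', 2}, {'F', 16}, {'Y', 4}, {'X', 4}});
    expr.SetExpression("B*F");
    return expr.Evaluate();  // 32 -- this is the value pushed into gws
}
```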
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/custom_primitive.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/custom_primitive.cpp
index 4aeadc7a297da3..0f583a849fb754 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/custom_primitive.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/custom_primitive.cpp
@@ -71,6 +71,17 @@ struct custom_gpu_primitive_impl : typed_primitive_impl<custom_gpu_primitive> {
         return {kernels_cache.get_cached_kernel_id(_kernels[0])};
     }
 
+    void set_kernels(cldnn::kernels_cache::compiled_kernels kernels) override {
+        OPENVINO_ASSERT(kernels.size() == 1, "Expected compiled kernels for exactly one primitive");
+        auto& kernel_vec = kernels.begin()->second;
+        _kernels.clear();
+        _kernels.resize(kernel_vec.size());
+        for (auto& k : kernel_vec) {
+            auto sub_kernel_idx = k.second;
+            _kernels[sub_kernel_idx] = k.first;
+        }
+    }
+
     void set_arguments_impl(custom_gpu_primitive_inst& instance) override {
         auto& stream = instance.get_network().get_stream();
         kernel_arguments_data args;
@@ -211,14 +222,16 @@ static void add_layout_to_jit
     mem_consts.AddConstant(kernel_selector::MakeJitConstant(name + "_OFFSET", std::to_string(offset)));
 }
 
-static std::string get_jit_constant(const custom_gpu_primitive_node& outer, const kernel_impl_params& impl_param) {
+static std::string get_jit_constant(const custom_gpu_primitive_node& outer,
+                                    const kernel_impl_params& impl_param,
+                                    const std::vector<size_t>& gws,
+                                    const std::vector<size_t>& lws) {
     kernel_selector::jit_constants mem_consts{
         kernel_selector::MakeJitConstant("NUM_INPUTS", std::to_string(outer.get_dependencies().size()))};
 
-    const auto primitive = outer.get_primitive().get();
     mem_consts.AddConstants({
-        kernel_selector::MakeJitConstant("GLOBAL_WORKSIZE", primitive->gws),
-        kernel_selector::MakeJitConstant("LOCAL_WORKSIZE", primitive->lws),
+        kernel_selector::MakeJitConstant("GLOBAL_WORKSIZE", gws),
+        kernel_selector::MakeJitConstant("LOCAL_WORKSIZE", lws),
     });
 
     for (size_t i = 0; i < impl_param.input_layouts.size(); i++) {
@@ -239,17 +252,38 @@
 static std::unique_ptr<primitive_impl> create(const custom_gpu_primitive_node& arg, const kernel_impl_params& impl_param) {
     const auto primitive = arg.get_primitive().get();
 
+    const auto& orig_output_layout = impl_param.get_output_layout();
+    OPENVINO_ASSERT(orig_output_layout.is_static(), "Output layout must be static when the primitive_impl is created!");
+
+    // Re-evaluate the work-size rules now that the shapes are static; the dimension
+    // source is the selected input when calcWgDimInputIdx is non-negative, else the output.
+    std::vector<size_t> gws, lws;
+    const auto& wg_dims = primitive->calcWgDimInputIdx >= 0
+                              ? impl_param.get_input_layout(primitive->calcWgDimInputIdx).get_partial_shape()
+                              : orig_output_layout.get_partial_shape();
+    custom_gpu_primitive::update_work_group_size(orig_output_layout.get_partial_shape(),
+                                                 primitive->calcWgDimInputIdx,
+                                                 wg_dims,
+                                                 primitive->globalSizeRules,
+                                                 primitive->localSizeRules,
+                                                 gws,
+                                                 lws);
+
+    // Fall back to the statically configured work sizes when no rules are defined
+    if (gws.empty()) {
+        gws = primitive->gws;
+    }
+    if (lws.empty()) {
+        lws = primitive->lws;
+    }
+
     auto cl_kernel = std::make_shared<kernel_selector::clKernelData>();
     cl_kernel->code.kernelString = std::make_shared<kernel_selector::KernelString>();
     cl_kernel->code.kernelString->entry_point = primitive->kernel_entry_point;
     cl_kernel->code.kernelString->options = primitive->build_options;
-    cl_kernel->code.kernelString->jit = get_jit_constant(arg, impl_param);
+    cl_kernel->code.kernelString->jit = get_jit_constant(arg, impl_param, gws, lws);
     for (const auto& s : primitive->kernels_code) {
         cl_kernel->code.kernelString->str += s + "\n";
     }
 
-    cl_kernel->params.workGroups.global = primitive->gws;
-    cl_kernel->params.workGroups.local = primitive->lws;
+    cl_kernel->params.workGroups.global = gws;
+    cl_kernel->params.workGroups.local = lws;
 
     for (const auto& p : primitive->kernel_arguments) {
         cl_kernel->params.arguments.push_back(get_arg(p));
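A sketch of the recomputation that now happens at impl-creation time; the shapes, rules, and function name below are illustrative, not taken from a real config:

```cpp
#include "intel_gpu/primitives/custom_gpu_primitive.hpp"

// A 1x8x32x32 output that was dynamic at graph-build time is concrete once
// shape inference has run, so the rules yield the real ND-range.
void work_size_recompute_sketch() {
    std::vector<size_t> gws, lws;
    cldnn::custom_gpu_primitive::update_work_group_size(
        ov::PartialShape{1, 8, 32, 32},  // now-static output dims
        -1,                              // -1: derive the sizes from the output
        ov::PartialShape{},              // input dims are unused when the index is -1
        {"B", "F", "Y*X"},               // hypothetical global work-size rules
        {},                              // no local rules
        gws,
        lws);
    // gws == {1, 8, 1024}; lws stays empty, so create() falls back to primitive->lws.
}
```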
diff --git a/src/plugins/intel_gpu/src/graph/include/custom_gpu_primitive_inst.h b/src/plugins/intel_gpu/src/graph/include/custom_gpu_primitive_inst.h
index 365fdc774b8f54..333595d8851ee6 100644
--- a/src/plugins/intel_gpu/src/graph/include/custom_gpu_primitive_inst.h
+++ b/src/plugins/intel_gpu/src/graph/include/custom_gpu_primitive_inst.h
@@ -5,11 +5,22 @@
 #pragma once
 #include "intel_gpu/primitives/custom_gpu_primitive.hpp"
 #include "primitive_inst.h"
+#include "openvino/op/parameter.hpp"
 
 #include <string>
 
 namespace cldnn {
 
+template <>
+struct typed_program_node<custom_gpu_primitive> : public typed_program_node_base<custom_gpu_primitive> {
+    using parent = typed_program_node_base<custom_gpu_primitive>;
+
+public:
+    typed_program_node(std::shared_ptr<custom_gpu_primitive> prim, program& prog) : parent(prim, prog) {}
+    program_node& input() const { return get_dependency(0); }
+
+    std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
+};
 using custom_gpu_primitive_node = typed_program_node<custom_gpu_primitive>;
 
 template <>
@@ -19,12 +30,14 @@ class typed_primitive_inst<custom_gpu_primitive> : public typed_primitive_inst_base<custom_gpu_primitive> {
 
 public:
     template <typename ShapeType>
-    static std::vector<layout> calc_output_layouts(custom_gpu_primitive_node const& /*node*/, const kernel_impl_params& impl_param) {
+    static std::vector<layout> calc_output_layouts(custom_gpu_primitive_node const& node, const kernel_impl_params& impl_param) {
         assert(static_cast<bool>(impl_param.desc->output_data_types[0]) == false &&
                "Output data type forcing is not supported for "
                "custom_gpu_primitive_node!");
         layout output_layout = impl_param.typed_desc<custom_gpu_primitive>()->output_layout;
 
+        typed_primitive_inst::update_output_shape(impl_param, output_layout);
+
         // if the output layout format was set to any, it means the layer output format will be the same as the first input
         if (output_layout.format == format::any) {
             output_layout.format = impl_param.get_input_layout().format;
@@ -38,6 +51,8 @@ class typed_primitive_inst<custom_gpu_primitive> : public typed_primitive_inst_base<custom_gpu_primitive> {
                "custom_gpu_primitive_node!");
         layout output_layout = impl_param.typed_desc<custom_gpu_primitive>()->output_layout;
 
+        typed_primitive_inst::update_output_shape(impl_param, output_layout);
+
         // if the output layout format was set to any, it means the layer output format will be the same as the first
         // input
         if (output_layout.format == format::any) {
@@ -50,6 +65,34 @@ class typed_primitive_inst<custom_gpu_primitive> : public typed_primitive_inst_base<custom_gpu_primitive> {
 
 public:
     typed_primitive_inst(network& network, custom_gpu_primitive_node const& node);
+
+private:
+    static void update_output_shape(const kernel_impl_params& impl_param, layout& output_layout) {
+        bool is_dynamic_input = false;
+        const auto inp_sz = impl_param.get_input_layout_size();
+        for (size_t i = 0; i < inp_sz; i++) {
+            if (impl_param.get_input_layout(i).is_dynamic()) {
+                is_dynamic_input = true;
+                break;
+            }
+        }
+
+        // Run the op's shape inference only when the output is still dynamic and all input
+        // shapes are already known; otherwise keep the original output layout unchanged
+        // (it is static for a static model, or still dynamic earlier in the dynamic flow).
+        if (!is_dynamic_input && output_layout.is_dynamic()) {
+            ov::OutputVector new_inputs;
+            for (size_t i = 0; i < inp_sz; i++) {
+                const auto& input_layout = impl_param.get_input_layout(i);
+                auto input = std::make_shared<ov::op::v0::Parameter>(input_layout.data_type, input_layout.get_shape());
+                new_inputs.emplace_back(input);
+            }
+
+            auto op = impl_param.typed_desc<custom_gpu_primitive>()->op;
+            auto new_op = op->clone_with_new_inputs(new_inputs);
+            new_op->validate_and_infer_types();
+            auto new_outp_shape = new_op->get_output_shape(0);
+            output_layout.set_partial_shape(new_outp_shape);
+        }
+    }
 };
 
 using custom_gpu_primitive_inst = typed_primitive_inst<custom_gpu_primitive>;
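This deferred shape inference relies only on the stored op providing a faithful `clone_with_new_inputs()` and `validate_and_infer_types()`. A minimal sketch of an op satisfying that contract (the class name is illustrative; the functional test below defines a complete one):

```cpp
#include "openvino/op/op.hpp"

class MyCustomOp : public ov::op::Op {
public:
    OPENVINO_OP("MyCustomOp", "custom_opset");

    MyCustomOp() = default;
    explicit MyCustomOp(const ov::Output<ov::Node>& input) : Op({input}) {
        constructor_validate_and_infer_types();
    }

    // Re-run by update_output_shape() once the cloned op gets static Parameters.
    void validate_and_infer_types() override {
        // Shape-preserving op: propagate the (possibly dynamic) input shape.
        set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
    }

    // Must produce a faithful copy so the plugin can re-infer on new shapes.
    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
        OPENVINO_ASSERT(new_args.size() == 1, "Expected exactly one input");
        return std::make_shared<MyCustomOp>(new_args[0]);
    }
};
```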
diff --git a/src/plugins/intel_gpu/src/plugin/ops/custom.cpp b/src/plugins/intel_gpu/src/plugin/ops/custom.cpp
index 84c54da01786c6..ca0773a57d9f7c 100644
--- a/src/plugins/intel_gpu/src/plugin/ops/custom.cpp
+++ b/src/plugins/intel_gpu/src/plugin/ops/custom.cpp
@@ -169,65 +169,63 @@ void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op, CustomLayerPtr customLayer) {
     const std::string layerTitle("\n// Layer " + op->get_friendly_name() + " using Custom Layer " + customLayer->Name() + "\n");
     const std::string defineTitle("// Custom Layer User Defines\n");
 
-    auto dims = op->get_output_shape(0);
-    size_t N = (dims.size() > 0) ? dims[0] : 1;
-    size_t C = (dims.size() > 1) ? dims[1] : 1;
-    size_t H = (dims.size() > 2) ? dims[2] : 1;
-    size_t W = (dims.size() > 3) ? dims[3] : 1;
-    cldnn::tensor outputTensor = cldnn::tensor(cldnn::batch(N), cldnn::feature(C), cldnn::spatial(W, H));
+    auto dims = op->get_output_partial_shape(0);
+    int iidx = customLayer->InputDimSourceIndex();
 
-    cldnn::layout outputLayout = cldnn::layout(cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat, outputTensor);
+    size_t N = (dims.size() > 0) ? (dims[0].is_dynamic() ? -1 : dims[0].get_length()) : 1;
+    size_t C = (dims.size() > 1) ? (dims[1].is_dynamic() ? -1 : dims[1].get_length()) : 1;
+    size_t H = (dims.size() > 2) ? (dims[2].is_dynamic() ? -1 : dims[2].get_length()) : 1;
+    size_t W = (dims.size() > 3) ? (dims[3].is_dynamic() ? -1 : dims[3].get_length()) : 1;
 
-    // evaluate work sizes rules
-    std::vector<size_t> gws, lws;
+    cldnn::layout outputLayout;
+    if (dims.is_dynamic()) {
+        outputLayout = cldnn::layout(dims, cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat);
+    } else {
+        cldnn::tensor outputTensor = cldnn::tensor(cldnn::batch(N), cldnn::feature(C), cldnn::spatial(W, H));
+        outputLayout = cldnn::layout(cldnn::element_type_to_data_type(op->get_output_element_type(0)), outputFormat, outputTensor);
+    }
 
-    // assume output tensor is dimension source by default
-    int batchDim = outputTensor.batch[0];
-    int featureDim = outputTensor.feature[0];
-    int yDim = outputTensor.spatial[1];
-    int xDim = outputTensor.spatial[0];
-    int iidx = customLayer->InputDimSourceIndex();
+    std::vector<size_t> gws, lws;
 
-    std::string genericLayerName = layer_type_name_ID(op);
     // if input index is non-negative, take the dimensions from that input
     if (iidx >= 0) {
         if (static_cast<size_t>(iidx) >= op->get_input_size())
             OPENVINO_THROW("Invalid input tensor for index: ", iidx);
         auto inputDims = op->get_input_shape(iidx);
+        cldnn::custom_gpu_primitive::update_work_group_size(dims,
+                                                            iidx,
+                                                            inputDims,
+                                                            customLayer->GlobalSizeRules(),
+                                                            customLayer->LocalSizeRules(),
+                                                            gws,
+                                                            lws);
+    } else {
+        cldnn::custom_gpu_primitive::update_work_group_size(dims,
+                                                            iidx,
+                                                            ov::PartialShape(),
+                                                            customLayer->GlobalSizeRules(),
+                                                            customLayer->LocalSizeRules(),
+                                                            gws,
+                                                            lws);
+    }
+
+    std::string genericLayerName = layer_type_name_ID(op);
 
-        xDim = static_cast<int>(inputDims[inputDims.size() - 1]);
-        yDim = dims.size() > 1 ? static_cast<int>(inputDims[inputDims.size() - 2]) : 0;
-        featureDim = dims.size() > 2 ? static_cast<int>(inputDims[inputDims.size() - 3]) : 0;
-        batchDim = dims.size() > 3 ? static_cast<int>(inputDims[inputDims.size() - 4]) : 0;
-    }
-    const std::map<char, int> vars = {
-        { 'b', batchDim },   { 'B', batchDim },
-        { 'f', featureDim }, { 'F', featureDim },
-        { 'y', yDim },       { 'Y', yDim },
-        { 'x', xDim },       { 'X', xDim },
-    };
-    for (const auto& rule : customLayer->GlobalSizeRules()) {
-        SimpleMathExpression expr;
-        expr.SetVariables(vars);
-        expr.SetExpression(rule);
-        gws.push_back(expr.Evaluate());
-    }
-    for (const auto& rule : customLayer->LocalSizeRules()) {
-        SimpleMathExpression expr;
-        expr.SetVariables(vars);
-        expr.SetExpression(rule);
-        lws.push_back(expr.Evaluate());
-    }
+    // Clone the op onto fresh Parameters so that storing it in the primitive
+    // does not keep the original model alive.
+    ov::OutputVector new_inputs;
+    for (size_t i = 0; i < op->get_input_size(); i++) {
+        auto input = std::make_shared<ov::op::v0::Parameter>(op->get_input_element_type(i), op->get_input_partial_shape(i));
+        new_inputs.emplace_back(input);
+    }
+    std::shared_ptr<ov::Node> op_bk = op->clone_with_new_inputs(new_inputs);
 
     auto customPrim = cldnn::custom_gpu_primitive(genericLayerName,
                                                   reordered_inputs,
-                                                  { layerTitle, defineTitle, layerDefines, customLayer->KernelSource() },
+                                                  {layerTitle, defineTitle, layerDefines, customLayer->KernelSource()},
                                                   customLayer->KernelEntry(),
                                                   kernelParameters,
                                                   customLayer->CompilerOptions(),
                                                   outputLayout,
                                                   gws,
-                                                  lws);
+                                                  lws,
+                                                  op_bk,
+                                                  iidx,
+                                                  customLayer->GlobalSizeRules(),
+                                                  customLayer->LocalSizeRules());
 
     p.add_primitive(*op, customPrim);
 
     auto prevLayerName = genericLayerName;
@@ -236,7 +234,7 @@ void CreateCustomOp(ProgramBuilder& p, const std::shared_ptr<ov::Node>& op, CustomLayerPtr customLayer) {
         auto reorderPrimName = genericLayerName + ProgramBuilder::m_postCustomLayerTag;
         p.add_primitive(*op,
                         cldnn::reorder(reorderPrimName,
                                        cldnn::input_info(genericLayerName),
-                                       cldnn::format::get_default_format(op->get_output_shape(0).size()),
+                                       cldnn::format::get_default_format(op->get_output_partial_shape(0).size()),
                                        customPrim.output_layout.data_type));
         prevLayerName = reorderPrimName;
     }
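For reference, `InputDimSourceIndex()` and the size rules consumed above come from the CustomLayer XML config. A sketch of the relevant fragment, assuming the documented `WorkSizes` attributes (an optional `dim="input,0"` attribute would make input 0 the dimension source, i.e. a non-negative iidx); it is shown as a C++ raw string to match the test below:

```cpp
// Hypothetical fragment: three global work-size rules, evaluated against the
// dimension source (the output by default) at build time and again in create().
const std::string work_sizes_fragment = R"(
    <WorkSizes global="B,F,Y" local="1,1,1"/>
)";
```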
diff --git a/src/plugins/intel_gpu/tests/functional/custom_op/custom_op_dynamic.cpp b/src/plugins/intel_gpu/tests/functional/custom_op/custom_op_dynamic.cpp
new file mode 100644
index 00000000000000..7509d0fb5d5e70
--- /dev/null
+++ b/src/plugins/intel_gpu/tests/functional/custom_op/custom_op_dynamic.cpp
@@ -0,0 +1,256 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <string>
+#include <vector>
+
+#include "openvino/op/constant.hpp"
+#include "openvino/op/parameter.hpp"
+#include "openvino/op/result.hpp"
+#include "openvino/runtime/core.hpp"
+#include "openvino/runtime/exec_model_info.hpp"
+#include "openvino/runtime/properties.hpp"
+#include "shared_test_classes/base/ov_behavior_test_utils.hpp"
+
+using namespace ::testing;
+
+namespace ov {
+namespace test {
+namespace intel_gpu {
+
+class CustomAddOp : public ov::op::Op {
+private:
+    float m_alpha;
+    float m_beta;
+
+public:
+    OPENVINO_OP("CustomAddOp", "gpu_opset");
+
+    CustomAddOp() = default;
+
+    CustomAddOp(const ov::Output<ov::Node>& input, float alpha, float beta) : Op({input}), m_alpha(alpha), m_beta(beta) {
+        constructor_validate_and_infer_types();
+    }
+
+    void validate_and_infer_types() override {
+        set_output_size(1);
+        set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
+    }
+
+    bool visit_attributes(ov::AttributeVisitor& visitor) override {
+        visitor.on_attribute("alpha", m_alpha);
+        visitor.on_attribute("beta", m_beta);
+        return true;
+    }
+
+    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override {
+        OPENVINO_ASSERT(new_args.size() == 1, "Incorrect number of new arguments");
+        return std::make_shared<CustomAddOp>(new_args[0], m_alpha, m_beta);
+    }
+
+    bool has_evaluate() const override {
+        return true;
+    }
+
+    bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override {
+        auto in = inputs[0];
+        auto out = outputs[0];
+        out.set_shape(in.get_shape());
+        for (size_t i = 0; i < out.get_size(); i++) {
+            out.data<float>()[i] = in.data<float>()[i] * m_alpha + m_beta;
+        }
+        return true;
+    }
+};
+
+using CustomOpDynamicTestParams = std::tuple<std::vector<ov::Shape>,             // input shapes
+                                             std::vector<std::vector<float>>>;  // input data
+class CustomOpDynamic : public ov::test::TestsCommon, public testing::WithParamInterface<CustomOpDynamicTestParams> {
+    void SetUp() override {
+        generate_config_files();
+    }
+
+    void TearDown() override {
+        ov::test::utils::removeFile(config_cl);
+        ov::test::utils::removeFile(config_xml);
+    }
+
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<CustomOpDynamicTestParams>& obj) {
+        std::vector<ov::Shape> input_shapes;
+        std::vector<std::vector<float>> input_datas;
+        std::tie(input_shapes, input_datas) = obj.param;
+
+        std::ostringstream result;
+        result << "input_shape=";
+        for (const auto& shape : input_shapes) {
+            result << shape;
+        }
+        return result.str();
+    }
+
+    static const size_t dim1 = 1;
+    void run() {
+        std::vector<ov::Shape> input_shapes;
+        std::vector<std::vector<float>> input_datas;
+        std::tie(input_shapes, input_datas) = GetParam();
+        ASSERT_EQ(input_shapes.size(), input_datas.size());
+
+        ov::Core core;
+        float alpha = 1.0f, beta = 0.1f;
+        auto model = generate_model_with_custom_add_op(alpha, beta, ov::PartialShape{-1, dim1, -1});
+
+        ov::AnyMap config = {ov::hint::inference_precision(ov::element::f32), {"CONFIG_FILE", config_xml}};
+        auto compiled_model = core.compile_model(model, ov::test::utils::DEVICE_GPU, config);
+
+        auto runtime_graph = compiled_model.get_runtime_model();
+        auto ops = runtime_graph->get_ordered_ops();
+
+        bool found_custom_op = false;
+        for (const auto& op : ops) {
+            if (op->get_rt_info()[ov::exec_model_info::LAYER_TYPE].as<std::string>() == "CustomGPUPrimitive") {
+                found_custom_op = true;
+                break;
+            }
+        }
+        ASSERT_TRUE(found_custom_op);
+
+        auto ireq = compiled_model.create_infer_request();
+        for (size_t i = 0; i < input_datas.size(); i++) {
+            auto input = ov::Tensor(ov::element::f32, input_shapes[i], input_datas[i].data());
+            ireq.set_input_tensor(0, input);
+            ireq.infer();
+            auto output = ireq.get_output_tensor(0);
+            std::vector<float> actual(output.data<float>(), output.data<float>() + output.get_size());
+
+            ASSERT_EQ(output.get_element_type(), element::f32);
+
+            const float* inp_data = input.data<float>();
+            for (size_t j = 0; j < output.get_size(); j++) {
+                ASSERT_FLOAT_EQ(actual[j], inp_data[j] * alpha + beta);
+            }
+        }
+    }
+
+protected:
+    std::string config_cl;
+    std::string config_xml;
+
+    void generate_config_files() {
+        config_cl = ov::test::utils::generateTestFilePrefix() + "_custom_op_dynamic.cl";
+        config_xml = ov::test::utils::generateTestFilePrefix() + "_custom_op_dynamic.xml";
+
+        std::string content_cl = R"(
+            __kernel void custom_add_kernel(
+                __global const INPUT0_TYPE* inp0,
+                __global OUTPUT0_TYPE* outp) {
+                const uint b = (uint)get_global_id(0);
+                const uint f = (uint)get_global_id(1);
+                const uint y = (uint)get_global_id(2);
+            #if INPUT0_DIMS_SIZE == 4
+                const uint x = 0;
+            #endif
+
+                const unsigned src_index = b*INPUT0_DIMS[1]*INPUT0_DIMS[2]*INPUT0_DIMS[3] + f*INPUT0_DIMS[2]*INPUT0_DIMS[3] + y*INPUT0_DIMS[3] + x;
+                const unsigned dst_index = src_index;
+
+                outp[dst_index] = inp0[src_index] * alpha + beta;
+            })";
+
+        std::string content_xml = R"(
+            <CustomLayer name="CustomAddOp" type="SimpleGPU" version="1">
+                <Kernel entry="custom_add_kernel">
+                    <Source filename=")" + config_cl + R"("/>
+                    <Define name="alpha" type="float" param="alpha" default="1.0"/>
+                    <Define name="beta" type="float" param="beta" default="0.0"/>
+                </Kernel>
+                <Buffers>
+                    <Tensor arg-index="0" type="input" port-index="0" format="BFYX"/>
+                    <Tensor arg-index="1" type="output" port-index="0" format="BFYX"/>
+                </Buffers>
+                <CompilerOptions options="-cl-mad-enable"/>
+                <WorkSizes global="B,F,Y"/>
+            </CustomLayer>)";
+
+        ov::test::utils::createFile(config_cl, content_cl);
+        ov::test::utils::createFile(config_xml, content_xml);
+    }
+
+    std::shared_ptr<ov::Model> generate_model_with_custom_add_op(float alpha, float beta, ov::PartialShape inp_shape) {
+        auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, inp_shape);
+        auto op = std::make_shared<CustomAddOp>(input, alpha, beta);
+        auto result = std::make_shared<ov::op::v0::Result>(op);
+        return std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{input}, "model_with_custom_op_dynamic");
+    }
+};
+
+class CustomOpStatic : public CustomOpDynamic {
+public:
+    void run() {
+        std::vector<ov::Shape> input_shapes;
+        std::vector<std::vector<float>> input_datas;
+        std::tie(input_shapes, input_datas) = GetParam();
+        ASSERT_EQ(input_shapes.size(), input_datas.size());
+        ASSERT_EQ(input_shapes.size(), 1u);
+
+        ov::Core core;
+        float alpha = 1.0f, beta = 0.1f;
+        auto model = generate_model_with_custom_add_op(alpha, beta, ov::PartialShape(input_shapes[0]));
+
+        ov::AnyMap config = {ov::hint::inference_precision(ov::element::f32), {"CONFIG_FILE", config_xml}};
+        auto compiled_model = core.compile_model(model, ov::test::utils::DEVICE_GPU, config);
+
+        auto runtime_graph = compiled_model.get_runtime_model();
+        auto ops = runtime_graph->get_ordered_ops();
+
+        bool found_custom_op = false;
+        for (const auto& op : ops) {
+            if (op->get_rt_info()[ov::exec_model_info::LAYER_TYPE].as<std::string>() == "CustomGPUPrimitive") {
+                found_custom_op = true;
+                break;
+            }
+        }
+        ASSERT_TRUE(found_custom_op);
+
+        auto ireq = compiled_model.create_infer_request();
+        auto input = ov::Tensor(ov::element::f32, input_shapes[0], input_datas[0].data());
+        ireq.set_input_tensor(0, input);
+        ireq.infer();
+        auto output = ireq.get_output_tensor(0);
+        std::vector<float> actual(output.data<float>(), output.data<float>() + output.get_size());
+
+        ASSERT_EQ(output.get_element_type(), element::f32);
+
+        const float* inp_data = input.data<float>();
+        for (size_t i = 0; i < output.get_size(); i++) {
+            ASSERT_FLOAT_EQ(actual[i], inp_data[i] * alpha + beta);
+        }
+    }
+};
+
+TEST_P(CustomOpDynamic, Accuracy) {
+    run();
+}
+
+TEST_P(CustomOpStatic, Accuracy) {
+    run();
+}
+
+const std::vector<ov::Shape> input_shapes{{1, CustomOpDynamic::dim1, 2}, {2, CustomOpDynamic::dim1, 3}};
+const std::vector<std::vector<float>> input_datas{{0.2f, 0.4f}, {0.2f, 0.4f, 0.3f, 0.5f, 0.7f, 0.9f}};
+
+INSTANTIATE_TEST_SUITE_P(smoke_GPU_Accuracy,
+                         CustomOpDynamic,
+                         ::testing::Combine(::testing::Values(input_shapes), ::testing::Values(input_datas)),
+                         CustomOpDynamic::getTestCaseName);
+
+const std::vector<ov::Shape> input_static_shapes{{2, 2, 3}};
+const std::vector<std::vector<float>> input_static_datas{{0.2f, 0.4f, 0.3f, 0.5f, 0.7f, 0.9f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}};
+
+INSTANTIATE_TEST_SUITE_P(smoke_GPU_Accuracy,
+                         CustomOpStatic,
+                         ::testing::Combine(::testing::Values(input_static_shapes), ::testing::Values(input_static_datas)),
+                         CustomOpStatic::getTestCaseName);
+
+}  // namespace intel_gpu
+}  // namespace test
+}  // namespace ov
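Condensed from the test above, this is the end-user flow the change enables: compile once with a dynamic input shape and reuse the same infer request across differently shaped inputs (the function name is illustrative; `model` and `config_xml` are assumed to be built as in the test):

```cpp
#include "openvino/runtime/core.hpp"

// `model` has a {-1, 1, -1} input containing the custom op; `config_xml`
// points to the CustomLayer config file generated above.
void dynamic_custom_op_usage(const std::shared_ptr<ov::Model>& model, const std::string& config_xml) {
    ov::Core core;
    auto compiled = core.compile_model(model, "GPU", ov::AnyMap{{"CONFIG_FILE", config_xml}});
    auto ireq = compiled.create_infer_request();
    for (const auto& shape : {ov::Shape{1, 1, 2}, ov::Shape{2, 1, 3}}) {
        ov::Tensor input(ov::element::f32, shape);  // fill with real data in practice
        ireq.set_input_tensor(0, input);
        ireq.infer();  // work sizes and output shape are recomputed per request
    }
}
```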
diff --git a/src/plugins/intel_gpu/tests/unit/CMakeLists.txt b/src/plugins/intel_gpu/tests/unit/CMakeLists.txt
index aa40a800295f02..1616c9e0bb687a 100644
--- a/src/plugins/intel_gpu/tests/unit/CMakeLists.txt
+++ b/src/plugins/intel_gpu/tests/unit/CMakeLists.txt
@@ -31,6 +31,7 @@ file(GLOB_RECURSE SOURCES_MAIN
     "${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/remote_tensor.cpp"
     "${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/usm_host_tensor.cpp"
     "${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/common_utils.cpp"
+    "${CMAKE_HOME_DIRECTORY}/src/plugins/intel_gpu/src/plugin/simple_math.cpp"
   )
 
 if (NOT ENABLE_ONEDNN_FOR_GPU)