diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt
index 495f6754b7e..b112432f95e 100644
--- a/external/CMakeLists.txt
+++ b/external/CMakeLists.txt
@@ -17,6 +17,11 @@ set(KINETO_URL
 use_mirror(VARIABLE KINETO_URL URL ${KINETO_URL})
 set(KINETO_MD5 f9b550591b3899fb267270c19484933f)
 
+set(CUDNN_FRONTEND_URL
+    https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.9.1.zip)
+use_mirror(VARIABLE CUDNN_FRONTEND_URL URL ${CUDNN_FRONTEND_URL})
+set(CUDNN_FRONTEND_MD5 0d28ff6aaa984dac4f7d16acfc48de72)
+
 set(EXTERNAL_TARGETS)
 
 if(WITH_TBB)
   # set(WITH_${threading_runtime_item} ON) in threading.cmake
@@ -33,6 +38,11 @@ list(APPEND EXTERNAL_TARGETS fmt)
 add_subdirectory(kineto)
 list(APPEND EXTERNAL_TARGETS kineto)
 
+if(BUILD_CUDA)
+  add_subdirectory(cudnn_frontend)
+  list(APPEND EXTERNAL_TARGETS cudnn_frontend)
+endif()
+
 mark_targets_as_system(${EXTERNAL_TARGETS})
 set_property(GLOBAL PROPERTY EXTERNAL_TARGETS ${EXTERNAL_TARGETS})
diff --git a/external/cudnn_frontend/CMakeLists.txt b/external/cudnn_frontend/CMakeLists.txt
new file mode 100644
index 00000000000..7375f373f57
--- /dev/null
+++ b/external/cudnn_frontend/CMakeLists.txt
@@ -0,0 +1,16 @@
+include(FetchContent)
+FetchContent_Declare(
+  cudnn_frontend
+  URL ${CUDNN_FRONTEND_URL}
+  URL_HASH MD5=${CUDNN_FRONTEND_MD5}
+)
+set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
+FetchContent_MakeAvailable(cudnn_frontend)
+
+set(CUDNN_FRONTEND_INSTALL_DIR ${THIRD_PARTY_DIR}/cudnn_frontend)
+install(
+  TARGETS cudnn_frontend
+  EXPORT oneflow
+  LIBRARY DESTINATION ${CUDNN_FRONTEND_INSTALL_DIR}/lib
+  ARCHIVE DESTINATION ${CUDNN_FRONTEND_INSTALL_DIR}/lib
+)
diff --git a/oneflow/core/device/cudnn_conv_util.cpp b/oneflow/core/device/cudnn_conv_util.cpp
index 849551b686f..39f8e7e83a4 100644
--- a/oneflow/core/device/cudnn_conv_util.cpp
+++ b/oneflow/core/device/cudnn_conv_util.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #ifdef WITH_CUDA
+#include "oneflow/core/framework/infer_util.h"
 #include "oneflow/core/device/cudnn_conv_util.h"
 #include "oneflow/core/device/cuda_util.h"
 #include "oneflow/core/common/cached_caller.h"
@@ -22,6 +23,7 @@ limitations under the License.
 #include "oneflow/core/job/global_for.h"
 #include "oneflow/core/framework/op_kernel.h"
+#include "oneflow/core/job/lazy_mode.h"
 
 namespace oneflow {
 
@@ -82,6 +84,7 @@ perf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res,
   FOR_RANGE(size_t, i, 0, perf_vec.size()) {
     // Note: Shouldn't all returned results be successful?
     CHECK_EQ(perf_vec[i].status, CUDNN_STATUS_SUCCESS);
+    // TODO: the workspace size limit may lead to results that mismatch PyTorch for large tensors
     if (perf_vec[i].memory > args.params.max_ws_size) { continue; }
     if (args.deterministic && perf_vec[i].determinism == CUDNN_NON_DETERMINISTIC) { continue; }
     found_algo_idx = i;
@@ -332,6 +335,106 @@ CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType
   params.max_ws_size = max_workspace_size;
 }
 
+cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id) {
+  auto dim = t->shape_view();
+  auto stride = t->stride();
+  return cudnn_frontend::TensorBuilder()
+      .setDim(dim.size(), dim.data())
+      .setStride(stride.size(), stride.data())
+      .setId(id)
+      .setAlignment(32)
+      .setDataType(GetCudnnDataType(t->data_type()))
+      .build();
+}
+
+cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id) {
+  auto dim = t.shape();
+  auto stride = t.stride();
+  return cudnn_frontend::TensorBuilder()
+      .setDim(dim.size(), dim.data())
+      .setStride(stride.size(), stride.data())
+      .setId(id)
+      .setAlignment(32)
+      .setDataType(GetCudnnDataType(t.data_type()))
+      .build();
+}
+
+cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx,
+                                           cudnnDataType_t data_type) {
+  if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
+    data_type = CUDNN_DATA_FLOAT;
+  }
+
+  std::vector<int64_t> padding;
+  const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
+  copy(padding_before.begin(), padding_before.end(), back_inserter(padding));
+
+  std::vector<int64_t> stride;
+  const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
+  copy(strides.begin(), strides.end(), back_inserter(stride));
+
+  std::vector<int64_t> dilation;
+  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
+  copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation));
+
+  uint64_t ndim = stride.size();
+  return cudnn_frontend::ConvDescBuilder()
+      .setDataType(data_type)
+      .setMathMode(CUDNN_CROSS_CORRELATION)
+      .setNDims(ndim)
+      .setStrides(ndim, stride.data())
+      .setPrePadding(ndim, padding.data())
+      .setPostPadding(ndim, padding.data())
+      .setDilation(ndim, dilation.data())
+      .build();
+}
+
+cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx,
+                                           cudnnDataType_t data_type) {
+  if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
+    data_type = CUDNN_DATA_FLOAT;
+  }
+
+  std::vector<int64_t> padding;
+  const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
+  copy(padding_before.begin(), padding_before.end(), back_inserter(padding));
+
+  std::vector<int64_t> stride;
+  const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
+  copy(strides.begin(), strides.end(), back_inserter(stride));
+
+  std::vector<int64_t> dilation;
+  const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
+  copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation));
+
+  uint64_t ndim = stride.size();
+  return cudnn_frontend::ConvDescBuilder()
+      .setDataType(data_type)
+      .setMathMode(CUDNN_CROSS_CORRELATION)
+      .setNDims(ndim)
+      .setStrides(ndim, stride.data())
+      .setPrePadding(ndim, padding.data())
+      .setPostPadding(ndim, padding.data())
+      .setDilation(ndim, dilation.data())
+      .build();
+}
+
+CudnnConvArgsV8::CudnnConvArgsV8(const user_op::InferContext& ctx, const user_op::TensorDesc& x,
+                                 const user_op::TensorDesc& y, const user_op::TensorDesc& w)
+    : xdesc(GetTensorDescriptor(x, 'x')),
+      ydesc(GetTensorDescriptor(y, 'y')),
+      wdesc(GetTensorDescriptor(w, 'w')),
+      cdesc(GetConvDescriptor(ctx, GetCudnnDataType(y.data_type()))),
+      beta(0.0f) {}
+
+CudnnConvArgsV8::CudnnConvArgsV8(const user_op::KernelComputeContext& ctx, const user_op::Tensor* x,
+                                 const user_op::Tensor* y, const user_op::Tensor* w)
+    : xdesc(GetTensorDescriptor(x, 'x')),
+      ydesc(GetTensorDescriptor(y, 'y')),
+      wdesc(GetTensorDescriptor(w, 'w')),
+      cdesc(GetConvDescriptor(ctx, GetCudnnDataType(y->data_type()))),
+      beta(0.0f) {}
+
 ManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args)
     : handle_(nullptr), x_dptr_(nullptr), w_dptr_(nullptr), y_dptr_(nullptr), ws_dptr_(nullptr) {
   x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, args.params.x_data_type);
@@ -424,6 +527,161 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso
                                          args.wdesc.Get(), algo, sz);
 }
 
+cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle,
+                                                const cudnnBackendDescriptorType_t desc,
+                                                const cudnn_frontend::Tensor& xdesc,
+                                                const cudnn_frontend::Tensor& ydesc,
+                                                const cudnn_frontend::Tensor& wdesc,
+                                                const cudnn_frontend::ConvDesc& cdesc, float beta) {
+  auto conv_op = cudnn_frontend::OperationBuilder(desc)
+                     .setxDesc(xdesc)
+                     .setyDesc(ydesc)
+                     .setwDesc(wdesc)
+                     .setcDesc(cdesc)
+                     .setBeta(beta)
+                     .build();
+  std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
+  auto op_graph = cudnn_frontend::OperationGraphBuilder()
+                      .setHandle(handle)
+                      .setOperationGraph(ops.size(), ops.data())
+                      .build();
+  return op_graph;
+}
+
+void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from,
+                         cudnn_frontend::EngineConfigList& to, bool deterministic) {
+  auto filter = [=](cudnnBackendDescriptor_t c) {
+    if (deterministic) {
+      if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) {
+        return true;
+      }
+    }
+    if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
+      return true;
+    }
+    return false;
+  };
+  cudnn_frontend::filter(from, to, filter);
+}
+
+std::vector<cudnn_frontend::GeneratorSource> GetGeneratorSources(
+    const cudnnBackendDescriptorType_t desc) {
+  bool deterministic = Singleton<ResourceDesc, ForSession>::Get()
+                           ->resource()
+                           .cudnn_conf()
+                           .cudnn_conv_use_deterministic_algo_only();
+  bool heuristic = ParseBooleanFromEnv("ONEFLOW_CUDNN_USE_HEURISTIC_MODE_B", false);
+  auto heur_mode = heuristic ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
+  // Method for engine config generator based on heuristics
+  const auto heurgen_method =
+      [deterministic,
+       heur_mode](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
+    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                          .setOperationGraph(opGraph)
+                          .setHeurMode(heur_mode)
+                          .build();
+    auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
+    cudnn_frontend::EngineConfigList filtered_configs;
+    FilterEngineConfigs(engine_configs, filtered_configs, deterministic);
+    return filtered_configs;
+  };
+  // Method for engine config generator based on fallback list
+  const auto fallback_method =
+      [desc,
+       deterministic](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
+    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
+                        .setOperationGraph(opGraph)
+                        .setOperation(desc)
+                        .build();
+    auto& fallback_list = fallback.getFallbackList();
+    cudnn_frontend::EngineConfigList filtered_configs;
+    FilterEngineConfigs(fallback_list, filtered_configs, deterministic);
+    return filtered_configs;
+  };
+  std::vector<cudnn_frontend::GeneratorSource> sources = {heurgen_method, fallback_method};
+  return sources;
+}
+
+cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle,
+                                                         const cudnnBackendDescriptorType_t desc,
+                                                         const cudnn_frontend::Tensor& xdesc,
+                                                         const cudnn_frontend::Tensor& ydesc,
+                                                         const cudnn_frontend::Tensor& wdesc,
+                                                         const cudnn_frontend::ConvDesc& cdesc,
+                                                         float beta, std::string& tag) {
+  auto op_graph = BuildConvOpGraph(handle, desc, xdesc, ydesc, wdesc, cdesc, beta);
+  tag = op_graph.getTag();
+  auto sources = GetGeneratorSources(desc);
+  cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data());
+  auto configs = generator.generate_engine_config(op_graph);
+  return configs;
+}
+
+bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag) {
+  static nlohmann::json errata_json_handle;
+  static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
+  if (!has_json) {
+    return false;
+  } else {
+    return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle,
+                                        []() { return true; });
+  }
+}
+
+void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
+                 user_op::Tensor* w, user_op::Tensor* buf,
+                 const cudnn_frontend::ExecutionPlan& plan) {
+  void* data[] = {x->mut_dptr(), y->mut_dptr(), w->mut_dptr()};
+  int64_t ids[] = {'x', 'y', 'w'};
+  auto variantPack = cudnn_frontend::VariantPackBuilder()
+                         .setWorkspacePointer(buf->mut_dptr())
+                         .setDataPointers(3, data)
+                         .setUids(3, ids)
+                         .build();
+  OF_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
+}
+
+void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
+                user_op::Tensor* w, user_op::Tensor* buf, cudnn_frontend::EngineConfigList& configs,
+                const std::string& tag) {
+  for (auto& config : configs) {
+    try {
+      auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                      .setHandle(handle)
+                      .setEngineConfig(config, tag)
+                      .build();
+      if (PlanErrataException(handle, plan.getTag())) { continue; }
+      RunConvPlan(handle, x, y, w, buf, plan);
+      return;
+    } catch (cudnn_frontend::cudnnException& e) {}
+  }
+}
+
+void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc,
+                          user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w,
+                          user_op::Tensor* b, const CudnnConvArgsV8& args) {
+  std::string tag;
+  auto configs = CudnnFrontendGetConfigs(handle, desc, args.xdesc, args.ydesc, args.wdesc,
+                                         args.cdesc, args.beta, tag);
+  TryConfigs(handle, x, y, w, b, configs, tag);
+}
+
+size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle,
+                                   cudnn_frontend::EngineConfigList& configs,
+                                   const std::string& tag) {
+  for (auto& config : configs) {
+    try {
+      auto plan = cudnn_frontend::ExecutionPlanBuilder()
+                      .setHandle(handle)
+                      .setEngineConfig(config, tag)
+                      .build();
+      if (PlanErrataException(handle, plan.getTag())) { continue; }
+      if (plan.getWorkspaceSize() > 0L) { return plan.getWorkspaceSize(); }
+    } catch (cudnn_frontend::cudnnException& e) {}
+  }
+  return 1L;
+}
+
 template<>
 struct CudnnConvAlgorithmSearch<cudnnConvolutionFwdAlgoPerf_t> {
   using perf_t = cudnnConvolutionFwdAlgoPerf_t;
diff --git a/oneflow/core/device/cudnn_conv_util.h b/oneflow/core/device/cudnn_conv_util.h
index e917572580b..90946a61a9e 100644
--- a/oneflow/core/device/cudnn_conv_util.h
+++ b/oneflow/core/device/cudnn_conv_util.h
@@ -18,8 +18,12 @@ limitations under the License.
 
 #ifdef WITH_CUDA
 
+#include "cudnn_frontend.h"
+#include "cudnn_frontend_EngineConfigGenerator.h"
+#include "oneflow/core/common/tensor_desc.h"
 #include "oneflow/core/device/cudnn_util.h"
 #include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/framework/user_op_tensor.h"
 
 namespace oneflow {
 
@@ -93,6 +97,20 @@ struct CudnnConvArgs final {
                 bool enable_pseudo_half);
 };
 
+struct CudnnConvArgsV8 final {
+  cudnn_frontend::Tensor xdesc;
+  cudnn_frontend::Tensor ydesc;
+  cudnn_frontend::Tensor wdesc;
+  cudnn_frontend::ConvDesc cdesc;
+  float beta;
+
+  OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgsV8);
+  explicit CudnnConvArgsV8(const user_op::InferContext& ctx, const user_op::TensorDesc& x,
+                           const user_op::TensorDesc& y, const user_op::TensorDesc& w);
+  explicit CudnnConvArgsV8(const user_op::KernelComputeContext& ctx, const user_op::Tensor* x,
+                           const user_op::Tensor* y, const user_op::Tensor* w);
+};
+
 class CudnnConvResource {
  public:
   CudnnConvResource() = default;
@@ -168,6 +186,22 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso
 cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
                                         cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz);
 
+cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle,
+                                                         const cudnnBackendDescriptorType_t desc,
+                                                         const cudnn_frontend::Tensor& xdesc,
+                                                         const cudnn_frontend::Tensor& ydesc,
+                                                         const cudnn_frontend::Tensor& wdesc,
+                                                         const cudnn_frontend::ConvDesc& cdesc,
+                                                         float beta, std::string& tag);
+
+void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc,
+                          user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w,
+                          user_op::Tensor* b, const CudnnConvArgsV8& args);
+
+size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle,
+                                   cudnn_frontend::EngineConfigList& configs,
+                                   const std::string& tag);
+
 template<typename perf_t>
 perf_t FindCudnnConvAlgorithm(CudnnConvArgs* args);
diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp
index b98c57afcb8..f8475c30202 100644
--- a/oneflow/user/kernels/conv_cudnn_kernels.cpp
+++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp
@@ -103,6 +103,7 @@ size_t InferTmpSizeWithCudnn(const user_op::TensorDesc* x, const user_op::Tensor
   CHECK_EQ(algo_perf.status, CUDNN_STATUS_SUCCESS)
       << "op (" << ctx.op_name() << ") find algorithm perference failed. algo: " << algo_perf.algo;
+  // TODO: the workspace size limit may lead to results that mismatch PyTorch for large tensors
   CHECK_LE(algo_perf.memory, workspace_size)
       << "op (" << ctx.op_name() << ") find algorithm " << algo_perf.algo << ", need memory "
       << algo_perf.memory << ", but cudnn_buf_limit_byte is " << workspace_size;
@@ -252,6 +253,109 @@ REGISTER_CONV_KERNEL(conv1d, 1);
 REGISTER_CONV_KERNEL(conv2d, 2);
 REGISTER_CONV_KERNEL(conv3d, 3);
 
+template<size_t NDims>
+class ConvGpuKernelV8 final : public user_op::OpKernel, public user_op::CudaGraphSupport {
+ public:
+  ConvGpuKernelV8() = default;
+  ~ConvGpuKernelV8() = default;
+
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+
+  std::shared_ptr<ConvCudnnOpKernelCache> CreateConvCudnnOpKernelCache(
+      user_op::KernelCacheContext* ctx) const {
+    const auto& data_format = ctx->Attr<std::string>("data_format");
+    int32_t filters = ctx->Attr<int32_t>("filters");
+
+    std::shared_ptr<ConvCudnnOpKernelCache> state(new ConvCudnnOpKernelCache());
+
+    const user_op::TensorDesc* bias = ctx->TensorDesc4ArgNameAndIndex("bias", 0);
+    if (bias != nullptr) {
+      state->bias_desc.reset(
+          GetBiasCudnnTensorDesc(data_format, filters, bias->data_type()));
+    }
+
+    return state;
+  }
+
+  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
+      user_op::KernelCacheContext* ctx) const override {
+    return CreateConvCudnnOpKernelCache(ctx);
+  }
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,
+               const user_op::OpKernelCache* cache) const override {
+    // process context data
+    auto input = ctx->Tensor4ArgNameAndIndex("in", 0);
+    auto output = ctx->Tensor4ArgNameAndIndex("out", 0);
+    auto weight = ctx->Tensor4ArgNameAndIndex("weight", 0);
+    auto buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+
+    if (input->shape_view().elem_cnt() == 0) return;
+
+    CudnnConvArgsV8 args(*ctx, input, output, weight);
+    // process add_to_output
+    if (ctx->has_input("_add_to_output", 0)) {
+      auto add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
+      Memcpy<DeviceType::kCUDA>(
+          ctx->stream(), output->mut_dptr(), add_to_output->dptr(),
+          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
+      args.beta = 1.0f;
+    }
+
+    // trigger conv compute
+    auto handle = ctx->stream()->As<ep::CudaStream>()->cudnn_handle();
+    CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, input,
+                         output, weight, buffer, args);
+
+    // process bias
+    auto bias = ctx->Tensor4ArgNameAndIndex("bias", 0);
+    if (bias != nullptr) {
+      auto conv_cache = dynamic_cast<const ConvCudnnOpKernelCache*>(cache);
+      CHECK_NOTNULL(conv_cache);
+      const auto& data_format = ctx->Attr<std::string>("data_format");
+      CudnnTensorDesc output_desc(output->data_type(), output->shape_view(), data_format);
+      OF_CUDNN_CHECK(cudnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
+                                    CudnnSPOnePtr(input->data_type()), conv_cache->bias_desc->Get(),
+                                    bias->dptr(), CudnnSPOnePtr(input->data_type()),
+                                    output_desc.Get(), output->mut_dptr()));
+    }
+  }
+
+  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,
+                            user_op::OpKernelState* state) const override {
+    return Singleton<ResourceDesc, ForSession>::Get()
+        ->resource()
+        .cudnn_conf()
+        .cudnn_conv_heuristic_search_algo();
+  }
+};
+
+#define REGISTER_CONV_KERNEL_V8(op_name, ndims)                                          \
+  REGISTER_USER_KERNEL(#op_name)                                                         \
+      .SetCreateFn<ConvGpuKernelV8<ndims>>()                                             \
+      .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA                     \
+                       && user_op::HobEnvBool("ONEFLOW_KERNEL_ENABLE_CUDNN_V8", false))  \
+      .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {                      \
+        auto& input = ctx->InputTensorDesc("in", 0);                                     \
+        auto& output = ctx->InputTensorDesc("out", 0);                                   \
+        auto& weight = ctx->InputTensorDesc("weight", 0);                                \
+        CudnnConvArgsV8 args(*ctx, input, output, weight);                               \
+        auto handle = Singleton::Get()->Get();                                           \
+        std::string tag;                                                                 \
+        auto configs = CudnnFrontendGetConfigs(                                          \
+            handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, args.xdesc,  \
+            args.ydesc, args.wdesc, args.cdesc, args.beta, tag);                         \
+        size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag);       \
+        Singleton::Get()->Put(handle);                                                   \
+        return workspace_size;                                                           \
+      })                                                                                 \
+      .SetPriority(user_op::kKernelPriorityOptimized);
+
+REGISTER_CONV_KERNEL_V8(conv1d, 1);
+REGISTER_CONV_KERNEL_V8(conv2d, 2);
+REGISTER_CONV_KERNEL_V8(conv3d, 3);
+
 class ConvDataGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
  public:
   OF_DISALLOW_COPY_AND_MOVE(ConvDataGradGpuKernel);
@@ -325,6 +429,75 @@ REGISTER_USER_KERNEL("conv_data_grad")
       return Maybe<void>::Ok();
     });
 
+class ConvDataGradGpuKernelV8 final : public user_op::OpKernel, public user_op::CudaGraphSupport {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ConvDataGradGpuKernelV8);
+  ConvDataGradGpuKernelV8() = default;
+  ~ConvDataGradGpuKernelV8() = default;
+
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    auto input_diff = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    auto output_diff = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    auto weight = ctx->Tensor4ArgNameAndIndex("filter", 0);
+    auto buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+
+    if (output_diff->shape_view().elem_cnt() == 0) return;
+
+    CudnnConvArgsV8 args(*ctx, input_diff, output_diff, weight);
+    // process add_to_output
+    if (ctx->has_input("_add_to_output", 0)) {
+      auto add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
+      Memcpy<DeviceType::kCUDA>(
+          ctx->stream(), input_diff->mut_dptr(), add_to_output->dptr(),
+          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
+      args.beta = 1.0f;
+    }
+
+    // trigger conv compute
+    auto handle = ctx->stream()->As<ep::CudaStream>()->cudnn_handle();
+    CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
+                         input_diff, output_diff, weight, buffer, args);
+  }
+
+  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,
+                            user_op::OpKernelState* state) const override {
+    return Singleton<ResourceDesc, ForSession>::Get()
+        ->resource()
+        .cudnn_conf()
+        .cudnn_conv_heuristic_search_algo();
+  }
+};
+
+REGISTER_USER_KERNEL("conv_data_grad")
+    .SetCreateFn<ConvDataGradGpuKernelV8>()
+    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA
+                     && user_op::HobEnvBool("ONEFLOW_KERNEL_ENABLE_CUDNN_V8", false))
+    .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {
+      auto& input_diff = ctx->InputTensorDesc("dx", 0);
+      auto& output_diff = ctx->InputTensorDesc("dy", 0);
+      auto& weight = ctx->InputTensorDesc("filter", 0);
+      CudnnConvArgsV8 args(*ctx, input_diff, output_diff, weight);
+      auto handle = Singleton::Get()->Get();
+      std::string tag;
+      auto configs = CudnnFrontendGetConfigs(
+          handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, args.xdesc,
+          args.ydesc, args.wdesc, args.cdesc, args.beta, tag);
+      size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag);
+      Singleton::Get()->Put(handle);
+      return workspace_size;
+    })
+    .SetInplaceProposalFn([](const user_op::InferContext& ctx,
+                             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {
+      if (ctx.has_input("_add_to_output", 0)) {
+        OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true));
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetPriority(user_op::kKernelPriorityOptimized);
+
 class ConvFilterGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
  public:
   OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradGpuKernel);
@@ -384,6 +557,65 @@ REGISTER_USER_KERNEL("conv_filter_grad")
           cudnn_conf.cudnn_conv_force_bwd_filter_algo());
     });
 
+class ConvFilterGradGpuKernelV8 final : public user_op::OpKernel, public user_op::CudaGraphSupport {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ConvFilterGradGpuKernelV8);
+  ConvFilterGradGpuKernelV8() = default;
+  ~ConvFilterGradGpuKernelV8() = default;
+
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    auto input = ctx->Tensor4ArgNameAndIndex("x", 0);
+    auto output_diff = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    auto weight_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0);
+    auto buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+
+    if (input->shape_view().elem_cnt() == 0) {
+      Memset<DeviceType::kCUDA>(
+          ctx->stream(), weight_diff->mut_dptr(), 0,
+          weight_diff->shape_view().elem_cnt() * GetSizeOfDataType(weight_diff->data_type()));
+      return;
+    }
+
+    CudnnConvArgsV8 args(*ctx, input, output_diff, weight_diff);
+
+    // trigger conv compute
+    auto handle = ctx->stream()->As<ep::CudaStream>()->cudnn_handle();
+    CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
+                         input, output_diff, weight_diff, buffer, args);
+  }
+
+  bool IsCudaGraphSupported(user_op::KernelInitContext* ctx,
+                            user_op::OpKernelState* state) const override {
+    return Singleton<ResourceDesc, ForSession>::Get()
+        ->resource()
+        .cudnn_conf()
+        .cudnn_conv_heuristic_search_algo();
+  }
+};
+
+REGISTER_USER_KERNEL("conv_filter_grad")
+    .SetCreateFn<ConvFilterGradGpuKernelV8>()
+    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA
+                     && user_op::HobEnvBool("ONEFLOW_KERNEL_ENABLE_CUDNN_V8", false))
+    .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {
+      auto& input = ctx->InputTensorDesc("x", 0);
+      auto& output_diff = ctx->InputTensorDesc("dy", 0);
+      auto& weight_diff = ctx->InputTensorDesc("filter_diff", 0);
+      CudnnConvArgsV8 args(*ctx, input, output_diff, weight_diff);
+      auto handle = Singleton::Get()->Get();
+      std::string tag;
+      auto configs = CudnnFrontendGetConfigs(
+          handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, args.xdesc,
+          args.ydesc, args.wdesc, args.cdesc, args.beta, tag);
+      size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag);
+      Singleton::Get()->Put(handle);
+      return workspace_size;
+    })
+    .SetPriority(user_op::kKernelPriorityOptimized);
+
 struct ConvBiasGradState final : public user_op::OpKernelState {
   std::unique_ptr<CudnnTensorDesc> bias_diff_desc;
 };