diff --git a/oneflow/core/autograd/gradient_funcs/scaled_dot_product_attention.cpp b/oneflow/core/autograd/gradient_funcs/scaled_dot_product_attention.cpp new file mode 100644 index 00000000000..b9e4181df2d --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/scaled_dot_product_attention.cpp @@ -0,0 +1,112 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/functional/functional.h" +#include "oneflow/core/common/container_util.h" +#if CUDA_VERSION >= 11070 + +namespace oneflow { + +namespace one { + +struct ScaledDotProductFlashAttentionCaptureState : public AutoGradCaptureState { + bool query_requires_grad = true; + bool key_requires_grad = true; + bool value_requires_grad = true; + size_t query_idx = 0; + size_t key_idx = 0; + size_t value_idx = 0; + size_t out_idx = 0; + size_t softmax_lse_idx = 0; + size_t rng_state_idx = 0; + float p_dropout = .0f; + float softmax_scale = .0f; + bool is_causal = false; +}; + +class ScaledDotProductFlashAttention + : public OpExprGradFunction<ScaledDotProductFlashAttentionCaptureState> { + public: + Maybe<void> Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be nullptr. "; + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe<void>::Ok(); + } + + Maybe<void> Capture(ScaledDotProductFlashAttentionCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 3) << "Input size should be equal to 3. 
"; + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->p_dropout = JUST(composed_attrs.GetAttr<float>("p_dropout")); + ctx->softmax_scale = JUST(composed_attrs.GetAttr<float>("softmax_scale")); + ctx->is_causal = JUST(composed_attrs.GetAttr<bool>("is_causal")); + ctx->query_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad(); + ctx->key_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad(); + ctx->value_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad(); + ctx->query_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0))); + ctx->key_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1))); + ctx->value_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2))); + ctx->out_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0))); + ctx->softmax_lse_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1))); + ctx->rng_state_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 2))); + return Maybe<void>::Ok(); + } + + Maybe<void> Apply(const ScaledDotProductFlashAttentionCaptureState* ctx, + const TensorTuple& out_grads, TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 3) << "Out grads size should be equal to 3. "; + std::shared_ptr<TensorTuple> grads; + in_grads->resize(3); + grads = JUST(functional::ScaledDotProductFlashAttentionGrad( + JUST(oneflow::VectorAt(out_grads, 0)), + JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->query_idx)), + JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->key_idx)), + JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->value_idx)), + JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->out_idx)), + JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->softmax_lse_idx)), + JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->rng_state_idx)), ctx->p_dropout, + ctx->is_causal, ctx->softmax_scale)); + + if (ctx->query_requires_grad) { + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0)); + } + if (ctx->key_requires_grad) { + JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1)); + } + if (ctx->value_requires_grad) { + JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2)); + } + + return Maybe<void>::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("scaled_dot_product_flash_attention", + ScaledDotProductFlashAttention); + +} // namespace one + +} // namespace oneflow + +#endif // CUDA_VERSION >= 11070 diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index ea9bb659220..391eb3e8522 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2692,6 +2692,10 @@ signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention" bind_python: True +- name: "scaled_dot_product_attention_grad" + signature: "TensorTuple (Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor softmax_lse, Tensor rng_state, Float dropout_p=0.0, Bool is_causal=False, Float scale=0.0) => ScaledDotProductFlashAttentionGrad" + bind_python: False + - name: "fused_multi_head_attention_inference" signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 
value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 648e0832e98..db638fca99d 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -5538,6 +5538,135 @@ class ScaledDotProductFlashAttentionFunctor { #endif // CUDA_VERSION >= 11070 }; +class ScaledDotProductFlashAttentionGradFunctor { + public: + ScaledDotProductFlashAttentionGradFunctor() { +#if CUDA_VERSION >= 11070 + op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention_grad") + .Input("grad_out") + .Input("query") + .Input("key") + .Input("value") + .Input("out") + .Input("softmax_lse") + .Input("rng_state") + .Output("grad_q") + .Output("grad_k") + .Output("grad_v") + .Build()); +#endif + } + + Maybe<TensorTuple> operator()( + const std::shared_ptr<one::Tensor>& grad_out, const std::shared_ptr<one::Tensor>& query, + const std::shared_ptr<one::Tensor>& key, const std::shared_ptr<one::Tensor>& value, + const std::shared_ptr<one::Tensor>& out, const std::shared_ptr<one::Tensor>& softmax_lse, + const std::shared_ptr<one::Tensor>& rng_state, const float& dropout_p, const bool& is_causal, + const float& scale) const { +#if CUDA_VERSION >= 11070 + // grad_out (batch x q_seq_len x num_heads x head_size) + // query (batch x q_seq_len x num_heads x head_size_padded) + // key (batch x kv_seq_len x num_heads_k x head_size_padded) + // value (batch x kv_seq_len x num_heads_k x head_size_padded) + // out (batch x q_seq_len x num_heads x head_size_padded) + // softmax_lse (batch x num_heads x q_seq_len) + const auto head_size = grad_out->shape()->At(3); + const auto head_size_padded = query->shape()->At(3); + const auto batch_size = query->shape()->At(0); + const auto seqlen_q = query->shape()->At(1); + const auto seqlen_k = key->shape()->At(1); + const auto num_heads = query->shape()->At(2); + const auto num_heads_k = key->shape()->At(2); + CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0)) + << " key has different batch size from query."; + CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0)) + << " value has different batch size from query."; + CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0)) + << " grad_out has different batch size from query."; + CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0)) + << " out has different batch size from query."; + CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0)) + << " softmax_lse has different batch size from query."; + CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2)) + << " grad_out has different num_heads from query."; + CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1)) + << " softmax_lse has different num_heads from query."; + CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2)) + << " value has different num_heads from key."; + CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1)) + << " grad_out has different seq_len from query."; + CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2)) + << " softmax_lse has different seq_len from query."; + CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3)) + << " key has different head dims from query."; + CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3)) + << " value has different head dims from query."; + CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3)) + << " out has different head dims from query."; + + bool padded = head_size % 8; + + auto grad_out_ = padded ? 
JUST(pad_last_dim<8>(grad_out)) : grad_out; + + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal", + "window_size_left", "window_size_right"); + attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1); + + auto output = std::make_shared<TensorTuple>(3); + auto output_ = JUST(OpInterpUtil::Dispatch<TensorTuple>( + *op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs)); + CHECK_EQ(output_->size(), 3); + auto grad_q_ = (*output_)[0]; + auto grad_k_ = (*output_)[1]; + auto grad_v_ = (*output_)[2]; + + std::shared_ptr<one::Tensor> grad_q_padded, grad_k_padded, grad_v_padded; + + bool expanded = num_heads != num_heads_k; + + grad_q_padded = grad_q_; + if (expanded) { + grad_k_padded = JUST(functional::ReduceSum( + JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k, + num_heads / num_heads_k, head_size_padded})), + {3}, false, grad_k_->dtype())); + grad_v_padded = JUST(functional::ReduceSum( + JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k, + num_heads / num_heads_k, head_size_padded})), + {3}, false, grad_v_->dtype())); + } else { + grad_k_padded = grad_k_; + grad_v_padded = grad_v_; + } + + auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0}, + {batch_size, seqlen_q, num_heads, head_size}, + {1, 1, 1, 1}, false)) + : grad_q_padded; + auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0}, + {batch_size, seqlen_k, num_heads_k, head_size}, + {1, 1, 1, 1}, false)) + : grad_k_padded; + auto grad_v = padded ? JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0}, + {batch_size, seqlen_k, num_heads_k, head_size}, + {1, 1, 1, 1}, false)) + : grad_v_padded; + + (*output)[0] = grad_q; + (*output)[1] = grad_k; + (*output)[2] = grad_v; + return output; + +#endif // CUDA_VERSION >= 11070 + + UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070."; + } + + private: +#if CUDA_VERSION >= 11070 + std::shared_ptr<OpExpr> op_; +#endif // CUDA_VERSION >= 11070 +}; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -5676,6 +5805,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate"); m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad"); m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention"); + m.add_functor<impl::ScaledDotProductFlashAttentionGradFunctor>( + "ScaledDotProductFlashAttentionGrad"); } } // namespace functional diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 9042aac9422..f2a91b4b032 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -2903,6 +2903,35 @@ def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_produc let has_data_type_infer_fn = 1; } +def OneFlow_ScaledDotProductFlashAttentionGradOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> { + let input = (ins + OneFlow_Tensor:$grad_out, + OneFlow_Tensor:$query, + OneFlow_Tensor:$key, + OneFlow_Tensor:$value, + OneFlow_Tensor:$out, + OneFlow_Tensor:$softmax_lse, + OneFlow_Tensor:$rng_state, + Optional<OneFlow_Tensor>:$alibi_slopes_ + ); + let output = (outs + OneFlow_Tensor:$grad_q, + OneFlow_Tensor:$grad_k, + OneFlow_Tensor:$grad_v + ); + let attrs = (ins + DefaultValuedAttr<F32Attr, "0.">:$p_dropout, + DefaultValuedAttr<F32Attr, "0.">:$softmax_scale, + DefaultValuedAttr<BoolAttr, "false">:$is_causal, + SI32Attr:$window_size_left, + SI32Attr:$window_size_right + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def 
OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$query, diff --git a/oneflow/user/kernels/scaled_dot_product_attention_grad_kernel.cu b/oneflow/user/kernels/scaled_dot_product_attention_grad_kernel.cu new file mode 100644 index 00000000000..36913bc027c --- /dev/null +++ b/oneflow/user/kernels/scaled_dot_product_attention_grad_kernel.cu @@ -0,0 +1,247 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/just.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/shape_view.h" +#include "oneflow/core/common/throw.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/op_kernel.h" +#include "oneflow/core/framework/user_op_tensor.h" + +#if CUDA_VERSION >= 11070 + +#ifdef WITH_CUTLASS + +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/core/cuda/elementwise.cuh" +#include "oneflow/core/ep/include/primitive/permute.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/warp/mma.h" +#include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/user/kernels/random_seed_util.h" +#include "oneflow/user/kernels/scaled_dot_product_attention_kernel.h" +// from flash_attention +#include "oneflow/user/kernels/scaled_dot_product_attention_util.h" + +namespace oneflow { + +namespace user_op { + +namespace { + +static size_t InferTmpBufferSizeForFlashAttentionGradKernel(InferContext* ctx) { + const auto& q_shape = ctx->InputTensorDesc("query", 0).shape(); + const int batch_size = q_shape.At(0); + const int seqlen_q = q_shape.At(1); + const int num_heads = q_shape.At(2); + const int head_size = q_shape.At(3); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + + size_t buffer_size = 0; + buffer_size += GetCudaAlignedSize(batch_size * num_heads * seqlen_q_rounded + * GetSizeOfDataType(DataType::kFloat)); + buffer_size += GetCudaAlignedSize(batch_size * seqlen_q_rounded * num_heads * head_size_rounded + * GetSizeOfDataType(DataType::kFloat)); + return buffer_size; +} + +class ScaledDotProductFlashAttentionGradKernel final : public user_op::OpKernel, + public user_op::CudaGraphSupport { + public: + ScaledDotProductFlashAttentionGradKernel() = default; + ~ScaledDotProductFlashAttentionGradKernel() override = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + const Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("grad_out", 0); + const Tensor* query = ctx->Tensor4ArgNameAndIndex("query", 0); + 
const Tensor* key = ctx->Tensor4ArgNameAndIndex("key", 0); + const Tensor* value = ctx->Tensor4ArgNameAndIndex("value", 0); + const Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const Tensor* softmax_lse = ctx->Tensor4ArgNameAndIndex("softmax_lse", 0); + const Tensor* rng_state = ctx->Tensor4ArgNameAndIndex("rng_state", 0); + const Tensor* alibi_slopes_ = nullptr; + if (ctx->has_input("alibi_slopes_", 0)) { + alibi_slopes_ = ctx->Tensor4ArgNameAndIndex("alibi_slopes_", 0); + } + + const float p_dropout = ctx->Attr<float>("p_dropout"); + const float softmax_scale = ctx->Attr<float>("softmax_scale"); + bool is_causal = ctx->Attr<bool>("is_causal"); + int window_size_left = ctx->Attr<int32_t>("window_size_left"); + int window_size_right = ctx->Attr<int32_t>("window_size_right"); + + Tensor* grad_q = ctx->Tensor4ArgNameAndIndex("grad_q", 0); + Tensor* grad_k = ctx->Tensor4ArgNameAndIndex("grad_k", 0); + Tensor* grad_v = ctx->Tensor4ArgNameAndIndex("grad_v", 0); + Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + void* tmp_ptr = tmp->mut_dptr(); + + auto* cuda_device = dynamic_cast<ep::CudaDevice*>(ctx->stream()->device()); + auto dprops = cuda_device->properties(); + auto* cuda_stream = ctx->stream()->As<ep::CudaStream>(); + + bool is_dropout = p_dropout > 0.0f; + + if (is_causal) { window_size_right = 0; } + + const int arch = cuda_stream->cuda_arch() / 10; + const bool is_supported_arch = (arch == 80 || arch == 86 || arch == 89 || arch == 90); + CHECK(is_supported_arch); + + const DataType data_type = query->data_type(); + const bool is_supported_dtype = + (data_type == DataType::kFloat16 || data_type == DataType::kBFloat16); + CHECK(is_supported_dtype); + CHECK_EQ(key->data_type(), data_type); + CHECK_EQ(value->data_type(), data_type); + CHECK_EQ(grad_out->data_type(), data_type); + CHECK_EQ(out->data_type(), data_type); + CHECK_EQ(softmax_lse->data_type(), DataType::kFloat); + CHECK_EQ(rng_state->data_type(), DataType::kUInt64); + + // check contiguous last dimension. + CHECK_EQ(CHECK_JUST(VectorAt(grad_out->stride(), 3)), 1); + CHECK_EQ(CHECK_JUST(VectorAt(query->stride(), 3)), 1); + CHECK_EQ(CHECK_JUST(VectorAt(key->stride(), 3)), 1); + CHECK_EQ(CHECK_JUST(VectorAt(value->stride(), 3)), 1); + CHECK_EQ(CHECK_JUST(VectorAt(out->stride(), 3)), 1); + + const int batch_size = query->shape_view().At(0); + const int seqlen_q = query->shape_view().At(1); + const int num_heads = query->shape_view().At(2); + const int head_size = query->shape_view().At(3); + const int seqlen_k = key->shape_view().At(1); + const int num_heads_k = key->shape_view().At(2); + const int head_size_og = grad_out->shape_view().At(3); + + // check tensor shape. 
+ CHECK_EQ(grad_out->shape_view().At(0), batch_size); + CHECK_EQ(grad_out->shape_view().At(1), seqlen_q); + CHECK_EQ(grad_out->shape_view().At(2), num_heads); + CHECK_EQ(grad_out->shape_view().At(3), head_size_og); + CHECK_EQ(query->shape_view().At(0), batch_size); + CHECK_EQ(query->shape_view().At(1), seqlen_q); + CHECK_EQ(query->shape_view().At(2), num_heads); + CHECK_EQ(query->shape_view().At(3), head_size); + CHECK_EQ(key->shape_view().At(0), batch_size); + CHECK_EQ(key->shape_view().At(1), seqlen_k); + CHECK_EQ(key->shape_view().At(2), num_heads_k); + CHECK_EQ(key->shape_view().At(3), head_size); + CHECK_EQ(value->shape_view().At(0), batch_size); + CHECK_EQ(value->shape_view().At(1), seqlen_k); + CHECK_EQ(value->shape_view().At(2), num_heads_k); + CHECK_EQ(value->shape_view().At(3), head_size); + CHECK_EQ(out->shape_view().At(0), batch_size); + CHECK_EQ(out->shape_view().At(1), seqlen_q); + CHECK_EQ(out->shape_view().At(2), num_heads); + CHECK_EQ(out->shape_view().At(3), head_size); + CHECK_EQ(softmax_lse->shape_view().At(0), batch_size); + CHECK_EQ(softmax_lse->shape_view().At(1), num_heads); + CHECK_EQ(softmax_lse->shape_view().At(2), seqlen_q); + + CHECK_GT(batch_size, 0); // batch size must be positive + CHECK_LE(head_size, 256); // only support head dimensions at most 256 + // FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout + // requires A100/A800 or H100/H800 + if (head_size > 192 && (head_size <= 224 || is_dropout)) { CHECK((arch == 80 || arch == 90)); } + CHECK(num_heads % num_heads_k + == 0); // Number of heads in key/value must divide number of heads in query + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // bool loop = seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + // size: batch_size x num_heads x seqlen_q_rounded; datatype: float + void* softmax_d_ptr = tmp_ptr; + tmp_ptr = reinterpret_cast<char*>(tmp_ptr) + + GetCudaAlignedSize(batch_size * num_heads * seqlen_q_rounded + * GetSizeOfDataType(DataType::kFloat)); + + // set to false by default. + // TODO(chende): can get from forward kernel (add input in python interface, it's only used for + // backward). + bool deterministic = false; + + void* dq_accum_ptr = nullptr; + if (loop) { + // size: batch_size x seqlen_q_rounded x num_heads x head_size_rounded; datatype: float + dq_accum_ptr = tmp_ptr; + } + + Flash_bwd_params params; + + set_params_dgrad(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, head_size, head_size_rounded, query, key, value, out, + grad_out, grad_q, grad_k, grad_v, nullptr, nullptr, + loop ? dq_accum_ptr : nullptr, + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, nullptr, const_cast<void*>(softmax_lse->dptr()), softmax_d_ptr, + p_dropout, softmax_scale, window_size_left, window_size_right, deterministic); + + params.dq_accum_split_stride = + !deterministic ? 
0 : seqlen_q_rounded * num_heads * head_size_rounded; + + auto launch = &run_mha_bwd; + + params.rng_state = const_cast(rng_state->dptr()); + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (seqlen_q > 0) { launch(params, cuda_stream->cuda_stream()); } + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(dtype) \ + REGISTER_USER_KERNEL("scaled_dot_product_flash_attention_grad") \ + .SetCreateFn() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("out", 0) == dtype)) \ + .SetInferTmpSizeFn(InferTmpBufferSizeForFlashAttentionGradKernel); + +REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kFloat16) +REGISTER_SCALED_DOT_PRODUCT_FLASH_ATTENTION_KERNEL(DataType::kBFloat16) + +} // namespace + +} // namespace user_op + +} // namespace oneflow + +#endif // WITH_CUTLASS + +#endif // CUDA_VERSION >= 11070 diff --git a/oneflow/user/kernels/scaled_dot_product_attention_util.h b/oneflow/user/kernels/scaled_dot_product_attention_util.h index 29d93c01da5..a836467318c 100644 --- a/oneflow/user/kernels/scaled_dot_product_attention_util.h +++ b/oneflow/user/kernels/scaled_dot_product_attention_util.h @@ -132,6 +132,54 @@ void set_params_fprop(Flash_fwd_params& params, #endif } +void set_params_dgrad(Flash_bwd_params& params, + // sizes + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, + const size_t h_k, const size_t d, const size_t d_rounded, + // device pointers + const Tensor* q, const Tensor* k, const Tensor* v, const Tensor* out, + const Tensor* dout, Tensor* dq, Tensor* dk, Tensor* dv, void* cu_seqlens_q_d, + void* cu_seqlens_k_d, void* dq_accum_d, void* dk_accum_d, void* dv_accum_d, + void* softmax_lse_d, void* dsoftmax_sum_d, float p_dropout, + float softmax_scale, int window_size_left, int window_size_right, + bool deterministic) { + set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, + d_rounded, q, k, v, const_cast(out), cu_seqlens_q_d, cu_seqlens_k_d, + nullptr, nullptr, softmax_lse_d, p_dropout, softmax_scale, window_size_left, + window_size_right); + + // Set the pointers and strides. 
+ params.do_ptr = const_cast(dout->dptr()); + params.do_row_stride = CHECK_JUST(VectorAt(dout->stride(), 1)); + params.do_head_stride = CHECK_JUST(VectorAt(dout->stride(), 2)); + params.dq_ptr = dq->mut_dptr(); + params.dk_ptr = dk->mut_dptr(); + params.dv_ptr = dv->mut_dptr(); + params.dq_row_stride = CHECK_JUST(VectorAt(dq->stride(), 1)); + params.dk_row_stride = CHECK_JUST(VectorAt(dk->stride(), 1)); + params.dv_row_stride = CHECK_JUST(VectorAt(dv->stride(), 1)); + params.dq_head_stride = CHECK_JUST(VectorAt(dq->stride(), 2)); + params.dk_head_stride = CHECK_JUST(VectorAt(dk->stride(), 2)); + params.dv_head_stride = CHECK_JUST(VectorAt(dv->stride(), 2)); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = CHECK_JUST(VectorAt(dout->stride(), 0)); + params.dq_batch_stride = CHECK_JUST(VectorAt(dq->stride(), 0)); + params.dk_batch_stride = CHECK_JUST(VectorAt(dk->stride(), 0)); + params.dv_batch_stride = CHECK_JUST(VectorAt(dv->stride(), 0)); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { @@ -144,6 +192,12 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split }); } +void run_mha_bwd(Flash_bwd_params& params, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { run_mha_bwd_(params, stream); }); + }); +} + // Find the number of splits that maximizes the occupancy. For example, if we have // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is // better than having 3 splits (efficiency = 0.67). However, we also don't want too many diff --git a/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp b/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp index 7e0ebc0102a..8a45c719b88 100644 --- a/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp +++ b/oneflow/user/ops/scaled_dot_product_flash_attention_op.cpp @@ -115,4 +115,120 @@ Maybe ScaledDotProductFlashAttentionOp::InferDataType(user_op::InferContex return Maybe::Ok(); } +Maybe ScaledDotProductFlashAttentionGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& dout_shape = ctx->InputShape("grad_out", 0); + const Shape& q_shape = ctx->InputShape("query", 0); + const Shape& k_shape = ctx->InputShape("key", 0); + const Shape& v_shape = ctx->InputShape("value", 0); + const Shape& out_shape = ctx->InputShape("out", 0); + const Shape& softmax_lse_shape = ctx->InputShape("softmax_lse", 0); + + auto batch_size = q_shape.At(0); + auto seqlen_q = q_shape.At(1); + auto num_heads = q_shape.At(2); + auto head_size = q_shape.At(3); + auto seqlen_k = k_shape.At(1); + auto num_heads_k = k_shape.At(2); + auto head_size_og = dout_shape.At(3); + + // check input tensor shape. 
+ CHECK_EQ_OR_RETURN(batch_size, k_shape.At(0)) << "query has different batch size from key."; + CHECK_EQ_OR_RETURN(batch_size, v_shape.At(0)) << "query has different batch size from value."; + CHECK_EQ_OR_RETURN(batch_size, dout_shape.At(0)) + << "query has different batch size from grad_out."; + CHECK_EQ_OR_RETURN(batch_size, out_shape.At(0)) << "query has different batch size from out."; + CHECK_EQ_OR_RETURN(batch_size, softmax_lse_shape.At(0)) + << "query has different batch size from softmax_lse."; + + CHECK_EQ_OR_RETURN(seqlen_k, v_shape.At(1)) << "key has different seqlen from value."; + CHECK_EQ_OR_RETURN(num_heads_k, v_shape.At(2)) << "key has different num_heads from value."; + + // dout should be padded in functional layer if needed. + CHECK_EQ_OR_RETURN(head_size_og, head_size) << "grad_out has different head_size from query"; + CHECK_EQ_OR_RETURN(head_size, k_shape.At(3)) << "query has different head_size from key"; + CHECK_EQ_OR_RETURN(head_size, v_shape.At(3)) << "query has different head_size from value"; + + // batch size must be positive. + CHECK_GT_OR_RETURN(batch_size, 0) << "batch size must be positive"; + + // only support head dimensions at most 256. + CHECK_LE_OR_RETURN(head_size_og, 256) << "only support head dimensions at most 256"; + + CHECK_EQ_OR_RETURN(num_heads % num_heads_k, 0) + << "number of heads in key/value must divide number of heads in query."; + + // grad_k/v should be expanded if needed (when num_heads != num_heads_k && num_heads % num_heads_k + // == 0). + ctx->SetOutputShape("grad_q", 0, Shape({batch_size, seqlen_q, num_heads, head_size})); + ctx->SetOutputShape("grad_k", 0, Shape({batch_size, seqlen_k, num_heads, head_size})); + ctx->SetOutputShape("grad_v", 0, Shape({batch_size, seqlen_k, num_heads, head_size})); + + return Maybe<void>::Ok(); +} + +Maybe<void> ScaledDotProductFlashAttentionGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return ScaledDotProductFlashAttentionGradOp::InferLogicalTensorDesc(ctx); +} + +Maybe<void> ScaledDotProductFlashAttentionGradOp::GetSbp(user_op::SbpContext* ctx) { + auto parallel_num = ctx->parallel_num(); + const Shape& q_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("query", 0).shape(); + const Shape& k_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("key", 0).shape(); + auto num_heads = q_shape.At(2); + auto num_heads_k = k_shape.At(2); + bool can_split_num_heads = + num_heads == num_heads_k || (!(num_heads % parallel_num) && !(num_heads_k % parallel_num)); + if (can_split_num_heads) { + // prefer to split on num_heads. + ctx->NewBuilder() + .Split(user_op::OpArg("grad_out", 0), 2) + .Split(user_op::OpArg("query", 0), 2) + .Split(user_op::OpArg("key", 0), 2) + .Split(user_op::OpArg("value", 0), 2) + .Split(user_op::OpArg("out", 0), 2) + .Split(user_op::OpArg("softmax_lse", 0), 1) + .Broadcast(user_op::OpArg("rng_state", 0)) + .Split(user_op::OpArg("grad_q", 0), 2) + .Split(user_op::OpArg("grad_k", 0), 2) + .Split(user_op::OpArg("grad_v", 0), 2) + .Build(); + } else { + // otherwise split on batch_size. 
+ ctx->NewBuilder() + .Split(user_op::OpArg("grad_out", 0), 0) + .Split(user_op::OpArg("query", 0), 0) + .Split(user_op::OpArg("key", 0), 0) + .Split(user_op::OpArg("value", 0), 0) + .Split(user_op::OpArg("out", 0), 0) + .Split(user_op::OpArg("softmax_lse", 0), 0) + .Broadcast(user_op::OpArg("rng_state", 0)) + .Split(user_op::OpArg("grad_q", 0), 0) + .Split(user_op::OpArg("grad_k", 0), 0) + .Split(user_op::OpArg("grad_v", 0), 0) + .Build(); + } + return Maybe<void>::Ok(); +} + +Maybe<void> ScaledDotProductFlashAttentionGradOp::InferDataType(user_op::InferContext* ctx) { + auto dout_datatype = ctx->InputDType("grad_out", 0); + auto q_datatype = ctx->InputDType("query", 0); + auto k_datatype = ctx->InputDType("key", 0); + auto v_datatype = ctx->InputDType("value", 0); + auto out_datatype = ctx->InputDType("out", 0); + + CHECK_EQ_OR_RETURN(q_datatype, k_datatype) << "query has different data type from key."; + CHECK_EQ_OR_RETURN(q_datatype, v_datatype) << "query has different data type from value."; + CHECK_EQ_OR_RETURN(q_datatype, dout_datatype) << "query has different data type from grad_out."; + CHECK_EQ_OR_RETURN(q_datatype, out_datatype) << "query has different data type from out."; + + ctx->SetOutputDType("grad_q", 0, q_datatype); + ctx->SetOutputDType("grad_k", 0, q_datatype); + ctx->SetOutputDType("grad_v", 0, q_datatype); + + return Maybe<void>::Ok(); +} + } // namespace oneflow diff --git a/python/oneflow/test/modules/test_scaled_dot_product_attention.py b/python/oneflow/test/modules/test_scaled_dot_product_attention.py index 51863b79643..08bb9e45a2b 100644 --- a/python/oneflow/test/modules/test_scaled_dot_product_attention.py +++ b/python/oneflow/test/modules/test_scaled_dot_product_attention.py @@ -44,42 +44,69 @@ def _test_scaled_dot_product_attention( num_heads_k = num_head_pair[1] seq_len_q = seq_len_pair[0] seq_len_kv = seq_len_pair[1] - query = flow.randn( - (batch_size, num_heads, seq_len_q, head_size), device="cuda", dtype=flow.float, - ).to(dtype) - key = flow.randn( - (batch_size, num_heads_k, seq_len_kv, head_size), - device="cuda", - dtype=flow.float, - ).to(dtype) - value = flow.randn( - (batch_size, num_heads_k, seq_len_kv, head_size), - device="cuda", - dtype=flow.float, - ).to(dtype) - - fused_out = ( - flow._C.scaled_dot_product_attention(query=query, key=key, value=value,) - .cpu() - .numpy() + query_raw = np.random.uniform( + low=-1, high=1, size=(batch_size, num_heads, seq_len_q, head_size) + ) + key_raw = np.random.uniform( + low=-1, high=1, size=(batch_size, num_heads_k, seq_len_kv, head_size) + ) + value_raw = np.random.uniform( + low=-1, high=1, size=(batch_size, num_heads_k, seq_len_kv, head_size) + ) + query_fused = flow.tensor(query_raw, dtype=dtype, device="cuda", requires_grad=True) + query_ref = flow.tensor(query_raw, dtype=dtype, device="cuda", requires_grad=True) + key_fused = flow.tensor(key_raw, dtype=dtype, device="cuda", requires_grad=True) + key_ref = flow.tensor(key_raw, dtype=dtype, device="cuda", requires_grad=True) + value_fused = flow.tensor(value_raw, dtype=dtype, device="cuda", requires_grad=True) + value_ref = flow.tensor(value_raw, dtype=dtype, device="cuda", requires_grad=True) + + fused_out = flow._C.scaled_dot_product_attention( + query=query_fused, key=key_fused, value=value_fused, ) if num_heads == num_heads_k: - ref_out = _scaled_dot_product_attention(query, key, value,).cpu().numpy() + ref_out = _scaled_dot_product_attention(query_ref, key_ref, value_ref,) else: # For GQA - ref_out = flow.empty(query.shape, device="cuda", dtype=dtype) + ref_out = 
flow.empty(query_fused.shape, device="cuda", dtype=dtype) stride = num_heads / num_heads_k for i in range(0, num_heads): j = int(i / stride) ref_out[:, i, :, :] = _scaled_dot_product_attention( - query[:, i, :, :], key[:, j, :, :], value[:, j, :, :] + query_ref[:, i, :, :], key_ref[:, j, :, :], value_ref[:, j, :, :] ) + total_out = ref_out.sum() + fused_out.sum() + total_out.backward() if dtype == flow.float16: - test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2)) + error_tol = 1e-2 elif dtype == flow.bfloat16: - test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-1, rtol=1e-1)) + error_tol = 1e-1 else: - test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-3, rtol=1e-3)) + error_tol = 1e-3 + + test_case.assertTrue( + np.allclose(ref_out.numpy(), fused_out.numpy(), atol=error_tol, rtol=error_tol) + ) + test_case.assertTrue( + np.allclose( + query_fused.grad.numpy(), + query_ref.grad.numpy(), + atol=error_tol, + rtol=error_tol, + ) + ) + test_case.assertTrue( + np.allclose( + key_fused.grad.numpy(), key_ref.grad.numpy(), atol=error_tol, rtol=error_tol + ) + ) + test_case.assertTrue( + np.allclose( + value_fused.grad.numpy(), + value_ref.grad.numpy(), + atol=error_tol, + rtol=error_tol, + ) + ) @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") diff --git a/python/setup.py b/python/setup.py index 93b600e50dd..54adeaa904b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -63,14 +63,14 @@ def get_version(): ONEFLOW_VERSION = get_version() if "cu11" in ONEFLOW_VERSION and "cu112" not in ONEFLOW_VERSION: - REQUIRED_PACKAGES.append("nvidia-cudnn-cu11") + REQUIRED_PACKAGES.append("nvidia-cudnn-cu11>=8.9,<9.0") REQUIRED_PACKAGES.append("nvidia-cublas-cu11") REQUIRED_PACKAGES.append("nvidia-nccl-cu11") REQUIRED_PACKAGES.append("nvidia-cusparse-cu11") REQUIRED_PACKAGES.append("nvidia-cufft-cu11") if "cu12" in ONEFLOW_VERSION: - REQUIRED_PACKAGES.append("nvidia-cudnn-cu12") + REQUIRED_PACKAGES.append("nvidia-cudnn-cu12>=8.9,<9.0") REQUIRED_PACKAGES.append("nvidia-cublas-cu12") REQUIRED_PACKAGES.append("nvidia-nccl-cu12") REQUIRED_PACKAGES.append("nvidia-cusparse-cu12")