Skip to content

Commit

Permalink
Merge branch 'master' into fused_rope_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
cccddd77 authored May 11, 2024
2 parents 278295d + ea585f6 commit 351e0e5
Show file tree
Hide file tree
Showing 9 changed files with 746 additions and 26 deletions.
112 changes: 112 additions & 0 deletions oneflow/core/autograd/gradient_funcs/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/common/container_util.h"
#if CUDA_VERSION >= 11070

namespace oneflow {

namespace one {

struct ScaledDotProductFlashAttentionCaptureState : public AutoGradCaptureState {
bool query_requires_grad = true;
bool key_requires_grad = true;
bool value_requires_grad = true;
size_t query_idx = 0;
size_t key_idx = 0;
size_t value_idx = 0;
size_t out_idx = 0;
size_t softmax_lse_idx = 0;
size_t rng_state_idx = 0;
float p_dropout = .0f;
float softmax_scale = .0f;
bool is_causal = false;
};

class ScaledDotProductFlashAttention
: public OpExprGradFunction<ScaledDotProductFlashAttentionCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr) << "fw_op_expr should not be None. ";
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> Capture(ScaledDotProductFlashAttentionCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 3) << "Input size should be equal to 3. ";
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->p_dropout = JUST(composed_attrs.GetAttr<float>("p_dropout"));
ctx->softmax_scale = JUST(composed_attrs.GetAttr<float>("softmax_scale"));
ctx->is_causal = JUST(composed_attrs.GetAttr<bool>("is_causal"));
ctx->query_requires_grad = JUST(oneflow::VectorAt(inputs, 0))->requires_grad();
ctx->key_requires_grad = JUST(oneflow::VectorAt(inputs, 1))->requires_grad();
ctx->value_requires_grad = JUST(oneflow::VectorAt(inputs, 2))->requires_grad();
ctx->query_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 0)));
ctx->key_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 1)));
ctx->value_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(inputs, 2)));
ctx->out_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 0)));
ctx->softmax_lse_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 1)));
ctx->rng_state_idx = ctx->SaveTensorForBackward(JUST(oneflow::VectorAt(outputs, 2)));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const ScaledDotProductFlashAttentionCaptureState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 3) << "Out grads size should be equal to 3. ";
std::shared_ptr<oneflow::one::TensorTuple> grads;
in_grads->resize(3);
grads = JUST(functional::ScaledDotProductFlashAttentionGrad(
JUST(oneflow::VectorAt(out_grads, 0)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->query_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->key_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->value_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->out_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->softmax_lse_idx)),
JUST(oneflow::VectorAt(ctx->SavedTensors(), ctx->rng_state_idx)), ctx->p_dropout,
ctx->is_causal, ctx->softmax_scale));

if (ctx->query_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(oneflow::VectorAt(*grads, 0));
}
if (ctx->key_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 1)) = JUST(oneflow::VectorAt(*grads, 1));
}
if (ctx->value_requires_grad) {
JUST(oneflow::VectorAt(*in_grads, 2)) = JUST(oneflow::VectorAt(*grads, 2));
}

return Maybe<void>::Ok();
}

private:
AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("scaled_dot_product_flash_attention",
ScaledDotProductFlashAttention);

} // namespace one

} // namespace oneflow

#endif // CUDA_VERSION >= 11070
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2692,6 +2692,10 @@
signature: "Tensor (Tensor query, Tensor key, Tensor value, Tensor attn_mask=None, Float dropout_p=0.0, Bool is_causal=False, Float scale=None, Int64 seed=0) => ScaledDotProductFlashAttention"
bind_python: True

- name: "scaled_dot_product_attention_grad"
signature: "TensorTuple (Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor softmax_lse, Tensor rng_state, Float dropout_p=0.0, Bool is_causal=False, Float scale=0.0) => ScaledDotProductFlashAttentionGrad"
bind_python: False

- name: "fused_multi_head_attention_inference"
signature: "Tensor (Tensor query, Tensor key, Tensor value, Int64 num_heads, Bool causal=False, Int64 query_hidden_slice_start=0, Int64 query_hidden_slice_end=-1, Int64 key_hidden_slice_start=0, Int64 key_hidden_slice_end=-1, Int64 value_hidden_slice_start=0, Int64 value_hidden_slice_end=-1, Tensor attn_bias=None, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInference"
bind_python: True
Expand Down
131 changes: 131 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5538,6 +5538,135 @@ class ScaledDotProductFlashAttentionFunctor {
#endif // CUDA_VERSION >= 11070
};

class ScaledDotProductFlashAttentionGradFunctor {
public:
ScaledDotProductFlashAttentionGradFunctor() {
#if CUDA_VERSION >= 11070
op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention_grad")
.Input("grad_out")
.Input("query")
.Input("key")
.Input("value")
.Input("out")
.Input("softmax_lse")
.Input("rng_state")
.Output("grad_q")
.Output("grad_k")
.Output("grad_v")
.Build());
#endif
}

Maybe<TensorTuple> operator()(
const std::shared_ptr<one::Tensor>& grad_out, const std::shared_ptr<one::Tensor>& query,
const std::shared_ptr<one::Tensor>& key, const std::shared_ptr<one::Tensor>& value,
const std::shared_ptr<one::Tensor>& out, const std::shared_ptr<one::Tensor>& softmax_lse,
const std::shared_ptr<one::Tensor>& rng_state, const float& dropout_p, const bool& is_causal,
const float& scale) const {
#if CUDA_VERSION >= 11070
// grad_out(batch x q_sqe_len x num_heads x head_size)
// query (batch x q_seq_len x num_heads x head_size_padded)
// key (batch x kv_seq_len x num_heads_k x head_size_padded)
// value (batch x kv_seq_len x num_heads_k x head_size_padded)
// out (batch x kv_seq_len x num_heads x head_size_padded)
// softmax_lse (batch x num_heads x q_seq_len)
const auto head_size = grad_out->shape()->At(3);
const auto head_size_padded = query->shape()->At(3);
const auto batch_size = query->shape()->At(0);
const auto seqlen_q = query->shape()->At(1);
const auto seqlen_k = key->shape()->At(1);
const auto num_heads = query->shape()->At(2);
const auto num_heads_k = key->shape()->At(2);
CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))
<< " key has different batch size from query.";
CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))
<< " value has different batch size from query.";
CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0))
<< " grad_out has different batch size from query.";
CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0))
<< " out has different batch size from query.";
CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0))
<< " softmax_lse has different batch size from query.";
CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2))
<< " grad_out has different num_heads from query.";
CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1))
<< " softmax_lse has different num_heads from query.";
CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2))
<< " value has different num_heads from key.";
CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1))
<< " grad_out has different seq_len from query.";
CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2))
<< " softmax_lse has different seq_len from query.";
CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3))
<< " key has different head dims from query.";
CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3))
<< " key has different head dims from query.";
CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3))
<< " out has different head dims from query.";

bool padded = head_size % 8;

auto grad_out_ = padded ? JUST(pad_last_dim<8>(grad_out)) : grad_out;

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal",
"window_size_left", "window_size_right");
attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1);

auto output = std::make_shared<TensorTuple>(3);
auto output_ = JUST(OpInterpUtil::Dispatch<TensorTuple>(
*op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs));
CHECK_EQ(output_->size(), 3);
auto grad_q_ = (*output_)[0];
auto grad_k_ = (*output_)[1];
auto grad_v_ = (*output_)[2];

std::shared_ptr<Tensor> grad_q_padded, grad_k_padded, grad_v_padded;

bool expanded = num_heads != num_heads_k;

grad_q_padded = grad_q_;
if (expanded) {
grad_k_padded = JUST(functional::ReduceSum(
JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k,
num_heads / num_heads_k, head_size_padded})),
{3}, false, grad_k_->dtype()));
grad_v_padded = JUST(functional::ReduceSum(
JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k,
num_heads / num_heads_k, head_size_padded})),
{3}, false, grad_v_->dtype()));
} else {
grad_k_padded = grad_k_;
grad_v_padded = grad_v_;
}

auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0},
{batch_size, seqlen_q, num_heads, head_size},
{1, 1, 1, 1}, false))
: grad_q_padded;
auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0},
{batch_size, seqlen_k, num_heads_k, head_size},
{1, 1, 1, 1}, false))
: grad_k_padded;
auto grad_v = padded ? JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0},
{batch_size, seqlen_k, num_heads_k, head_size},
{1, 1, 1, 1}, false))
: grad_v_padded;

(*output)[0] = grad_q;
(*output)[1] = grad_k;
(*output)[2] = grad_v;
return output;

#endif // CUDA_VERSION >= 11070

UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070.";
}

private:
#if CUDA_VERSION >= 11070
std::shared_ptr<OpExpr> op_;
#endif // CUDA_VERSION >= 11070
};
} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Expand Down Expand Up @@ -5676,6 +5805,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad");
m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention");
m.add_functor<impl::ScaledDotProductFlashAttentionGradFunctor>(
"ScaledDotProductFlashAttentionGrad");
}

} // namespace functional
Expand Down
29 changes: 29 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2903,6 +2903,35 @@ def OneFlow_ScaledDotProductFlashAttentionOp : OneFlow_BaseOp<"scaled_dot_produc
let has_data_type_infer_fn = 1;
}

def OneFlow_ScaledDotProductFlashAttentionGradOp : OneFlow_BaseOp<"scaled_dot_product_flash_attention_grad", [NoMemoryEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$grad_out,
OneFlow_Tensor:$query,
OneFlow_Tensor:$key,
OneFlow_Tensor:$value,
OneFlow_Tensor:$out,
OneFlow_Tensor:$softmax_lse,
OneFlow_Tensor:$rng_state,
Optional<OneFlow_Tensor>:$alibi_slopes_
);
let output = (outs
OneFlow_Tensor:$grad_q,
OneFlow_Tensor:$grad_k,
OneFlow_Tensor:$grad_v
);
let attrs = (ins
DefaultValuedAttr<F32Attr, "0.">:$p_dropout,
DefaultValuedAttr<F32Attr, "0.">:$softmax_scale,
DefaultValuedAttr<BoolAttr, "false">:$is_causal,
SI32Attr:$window_size_left,
SI32Attr:$window_size_right
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_head_attention_inference", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$query,
Expand Down
Loading

0 comments on commit 351e0e5

Please sign in to comment.