@@ -5538,6 +5538,139 @@ class ScaledDotProductFlashAttentionFunctor {
#endif  // CUDA_VERSION >= 11070
};

+class ScaledDotProductFlashAttentionGradFunctor {
+ public:
+  ScaledDotProductFlashAttentionGradFunctor() {
+#if CUDA_VERSION >= 11070
+    op_ = CHECK_JUST(one::OpBuilder("scaled_dot_product_flash_attention_grad")
+                         .Input("grad_out")
+                         .Input("query")
+                         .Input("key")
+                         .Input("value")
+                         .Input("out")
+                         .Input("softmax_lse")
+                         .Input("rng_state")
+                         .Output("grad_q")
+                         .Output("grad_k")
+                         .Output("grad_v")
+                         .Build());
+#endif
+  }
+
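+  // Backward of scaled_dot_product_flash_attention. `out`, `softmax_lse`, and
+  // `rng_state` are the auxiliary tensors saved by the forward op; the kernel
+  // replays the forward dropout mask from `rng_state`.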
+  Maybe<TensorTuple> operator()(
+      const std::shared_ptr<one::Tensor>& grad_out, const std::shared_ptr<one::Tensor>& query,
+      const std::shared_ptr<one::Tensor>& key, const std::shared_ptr<one::Tensor>& value,
+      const std::shared_ptr<one::Tensor>& out, const std::shared_ptr<one::Tensor>& softmax_lse,
+      const std::shared_ptr<one::Tensor>& rng_state, const float& dropout_p,
+      const bool& is_causal, const float& scale) const {
+#if CUDA_VERSION >= 11070
+    // grad_out    (batch x q_seq_len x num_heads x head_size)
+    // query       (batch x q_seq_len x num_heads x head_size_padded)
+    // key         (batch x kv_seq_len x num_heads_k x head_size_padded)
+    // value       (batch x kv_seq_len x num_heads_k x head_size_padded)
+    // out         (batch x q_seq_len x num_heads x head_size_padded)
+    // softmax_lse (batch x num_heads x q_seq_len)
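+    // grad_out arrives with the caller-visible head size, while
+    // query/key/value/out were padded to a multiple of 8 before the forward
+    // kernel ran; hence the separate head_size and head_size_padded below.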
+    const auto head_size = grad_out->shape()->At(3);
+    const auto head_size_padded = query->shape()->At(3);
+    const auto batch_size = query->shape()->At(0);
+    const auto seqlen_q = query->shape()->At(1);
+    const auto seqlen_k = key->shape()->At(1);
+    const auto num_heads = query->shape()->At(2);
+    const auto num_heads_k = key->shape()->At(2);
+    CHECK_EQ_OR_RETURN(batch_size, key->shape()->At(0))
+        << "key has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, value->shape()->At(0))
+        << "value has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, grad_out->shape()->At(0))
+        << "grad_out has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, out->shape()->At(0))
+        << "out has different batch size from query.";
+    CHECK_EQ_OR_RETURN(batch_size, softmax_lse->shape()->At(0))
+        << "softmax_lse has different batch size from query.";
+    CHECK_EQ_OR_RETURN(num_heads, grad_out->shape()->At(2))
+        << "grad_out has different num_heads from query.";
+    CHECK_EQ_OR_RETURN(num_heads, softmax_lse->shape()->At(1))
+        << "softmax_lse has different num_heads from query.";
+    CHECK_EQ_OR_RETURN(num_heads_k, value->shape()->At(2))
+        << "value has different num_heads from key.";
+    CHECK_EQ_OR_RETURN(seqlen_q, grad_out->shape()->At(1))
+        << "grad_out has different seq_len from query.";
+    CHECK_EQ_OR_RETURN(seqlen_q, softmax_lse->shape()->At(2))
+        << "softmax_lse has different seq_len from query.";
+    CHECK_EQ_OR_RETURN(head_size_padded, key->shape()->At(3))
+        << "key has different head dims from query.";
+    CHECK_EQ_OR_RETURN(head_size_padded, value->shape()->At(3))
+        << "value has different head dims from query.";
+    CHECK_EQ_OR_RETURN(head_size_padded, out->shape()->At(3))
+        << "out has different head dims from query.";
+
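+    // The flash attention kernel wants the head dim in multiples of 8, so pad
+    // grad_out's last dim the same way the forward functor padded q/k/v.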
+    const bool padded = head_size % 8 != 0;
+
+    auto grad_out_ = padded ? JUST(pad_last_dim<8>(grad_out)) : grad_out;
+
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p_dropout", "softmax_scale", "is_causal",
+                                                 "window_size_left", "window_size_right");
+    attrs.SetAllAttrs(dropout_p, scale, is_causal, -1, -1);
+
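+    // A window_size_left/right of -1 disables sliding-window (local)
+    // attention; the op is dispatched with the same attributes the forward
+    // pass used.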
+    auto output = std::make_shared<TensorTuple>(3);
+    auto output_ = JUST(OpInterpUtil::Dispatch<TensorTuple>(
+        *op_, {grad_out_, query, key, value, out, softmax_lse, rng_state}, attrs));
+    CHECK_EQ(output_->size(), 3);
+    auto grad_q_ = (*output_)[0];
+    auto grad_k_ = (*output_)[1];
+    auto grad_v_ = (*output_)[2];
+
+    std::shared_ptr<Tensor> grad_q_padded, grad_k_padded, grad_v_padded;
+
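+    // Under multi-query/grouped-query attention (num_heads_k < num_heads) the
+    // kernel returns grad_k/grad_v expanded to num_heads heads; fold them back
+    // by summing over each group of query heads that shared a kv head, e.g.
+    // (B, S_kv, 8, D) -> (B, S_kv, 2, 4, D) -> sum over dim 3 -> (B, S_kv, 2, D).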
+    const bool expanded = num_heads != num_heads_k;
+
+    grad_q_padded = grad_q_;
+    if (expanded) {
+      grad_k_padded = JUST(functional::ReduceSum(
+          JUST(functional::Reshape(grad_k_, {batch_size, seqlen_k, num_heads_k,
+                                             num_heads / num_heads_k, head_size_padded})),
+          {3}, false, grad_k_->dtype()));
+      grad_v_padded = JUST(functional::ReduceSum(
+          JUST(functional::Reshape(grad_v_, {batch_size, seqlen_k, num_heads_k,
+                                             num_heads / num_heads_k, head_size_padded})),
+          {3}, false, grad_v_->dtype()));
+    } else {
+      grad_k_padded = grad_k_;
+      grad_v_padded = grad_v_;
+    }
+
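+    // If grad_out was padded above, slice the gradients back down to the
+    // original head_size so their shapes match the forward inputs.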
+    auto grad_q = padded ? JUST(functional::Slice(grad_q_padded, {0, 0, 0, 0},
+                                                  {batch_size, seqlen_q, num_heads, head_size},
+                                                  {1, 1, 1, 1}, false))
+                         : grad_q_padded;
+    auto grad_k = padded ? JUST(functional::Slice(grad_k_padded, {0, 0, 0, 0},
+                                                  {batch_size, seqlen_k, num_heads_k, head_size},
+                                                  {1, 1, 1, 1}, false))
+                         : grad_k_padded;
+    auto grad_v = padded ? JUST(functional::Slice(grad_v_padded, {0, 0, 0, 0},
+                                                  {batch_size, seqlen_k, num_heads_k, head_size},
+                                                  {1, 1, 1, 1}, false))
+                         : grad_v_padded;
+
+    (*output)[0] = grad_q;
+    (*output)[1] = grad_k;
+    (*output)[2] = grad_v;
+    return output;
+
+#endif  // CUDA_VERSION >= 11070
+
+    UNIMPLEMENTED_THEN_RETURN() << "only support CUDA_VERSION >= 11070.";
+  }
+
+ private:
+#if CUDA_VERSION >= 11070
+  std::shared_ptr<OpExpr> op_;
+#endif  // CUDA_VERSION >= 11070
+};

}  // namespace impl

@@ -5676,6 +5809,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
  m.add_functor<impl::FusedClipGradFunctor>("FusedClipGrad");
  m.add_functor<impl::ScaledDotProductFlashAttentionFunctor>("ScaledDotProductFlashAttention");
+  m.add_functor<impl::ScaledDotProductFlashAttentionGradFunctor>(
+      "ScaledDotProductFlashAttentionGrad");
}

}  // namespace functional
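
A minimal call-site sketch (not part of the diff above): once registered, the functor is reachable through the generated functional API. The generated signature is assumed here from the functor's operator():

    // grad_out/query/key/value/out/softmax_lse/rng_state as saved by forward.
    auto grads = JUST(functional::ScaledDotProductFlashAttentionGrad(
        grad_out, query, key, value, out, softmax_lse, rng_state,
        /*dropout_p=*/0.0f, /*is_causal=*/false, /*scale=*/scale));
    // grads is a TensorTuple holding (grad_q, grad_k, grad_v).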