2929from ..recompute import recompute
3030
3131
32+ __all__ = ["DotProductAttention" , "MultiHeadAttention" ]
33+
34+
3235class FusedAttnFuncPackedQKV (paddle .autograd .PyLayer ):
3336 """Function for FusedAttention with packed QKV input"""
3437
@@ -129,7 +132,7 @@ def backward(ctx, d_out):
129132
130133
131134class DotProductAttention (paddle .nn .Layer ):
132- """Dot Product Attention Layer
135+ """
133136 Allows the model to jointly attend to information from different
134137 representation subspaces as described in the paper:
135138 `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
@@ -150,8 +153,7 @@ class DotProductAttention(paddle.nn.Layer):
150153 attention_type: {'self', 'cross'}, default = `self`
151154 type of attention operation.
152155 backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
153- backend to use for attention operation.
154-
156+ backend to use for attention operation.
155157 """
156158
157159 def __init__ (self ,
@@ -215,17 +217,17 @@ def forward(
215217 Parameters
216218 ----------
217219 query_layer : paddle.Tensor
218- Query tensor.
220+ Query tensor.
219221 key_value_layer : paddle.Tensor
220- Key tensor.
222+ Key tensor.
221223 attention_mask : Optional[paddle.Tensor], default = `None`
222- Boolean tensor used to mask out softmax input when not using attention.
224+ Boolean tensor used to mask out softmax input when not using attention.
223225 core_attention_bias_type: str, default = `no_bias`
224- only support no_bias type currently, {`no_bias`}
226+                currently only the `no_bias` type is supported, {`no_bias`}
225227 core_attention_bias: Optional[paddle.Tensor], default = `None`
226- Bias tensor for Q * K.T
227- set_zero: bool, defautl = `True`
228- Whether to use the fast path to set output tensors to 0 or not.
228+ Bias tensor for Q * K.T
229+ set_zero: bool, default = `True`
230+ Whether to use the fast path to set output tensors to 0 or not.
229231 """
230232
231233 backend = self .backend
@@ -358,7 +360,9 @@ def _pd_forward(
358360
359361
360362class MultiHeadAttention (paddle .nn .Layer ):
361- """Attention w/ QKV and Proj Gemms
363+ """
364+ Multi-head Attention (MHA), including Query,
365+ Key, Value and Output projection.
362366
363367 Parameters
364368 ----------
@@ -387,7 +391,8 @@ class MultiHeadAttention(paddle.nn.Layer):
387391 zero_centered_gamma: bool, default = `False`
388392 whether to zero initialize the gamma of the layernorm operation.
389393 backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
390- backend to use for attention operation.
394+                backend to use for attention operation. If set to 'paddle', a framework-only,
395                non-FP8 path is executed with limited optimization.
391396
392397 Parallelism parameters
393398 ----------------------
@@ -542,7 +547,6 @@ def forward(
542547 """
543548 MultiHeadAttention Layer.
544549
545-
546550 Parameters
547551 ----------
548552 hidden_states : paddle.Tensor
@@ -555,7 +559,7 @@ def forward(
555559                currently only the `no_bias` type is supported, {`no_bias`}
556560 core_attention_bias: Optional[paddle.Tensor], default = `None`
557561 Bias tensor for Q * K.T
558- set_zero: bool, defautl = `True`
562+ set_zero: bool, default = `True`
559563 Whether to use the fast path to set output tensors to 0 or not.
560564 recompute_core_attention: bool, default = `False`
561565 If true, forward activations for core attention are recomputed
0 commit comments