Skip to content

Commit 66d91d5

Browse files
committed
[paddle] add documentation (#489)
* paddle documentation Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * minor fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * review comments Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent d58c08c commit 66d91d5

File tree

11 files changed

+288
-31
lines changed

11 files changed

+288
-31
lines changed

docs/api/framework.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ Framework-specific API
1010

1111
pytorch
1212
jax
13+
paddle

docs/api/paddle.rst

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
..
2+
Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
4+
See LICENSE for license information.
5+
6+
paddle
7+
======
8+
9+
.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs)
10+
:members: forward
11+
12+
.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs)
13+
14+
.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs)
15+
:members: forward
16+
17+
.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs)
18+
:members: forward
19+
20+
.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs)
21+
:members: forward
22+
23+
.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
24+
:members: forward
25+
26+
.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs)
27+
:members: forward
28+
29+
.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
30+
:members: forward
31+
32+
.. autoapifunction:: transformer_engine.paddle.fp8_autocast
33+
34+
.. autoapifunction:: transformer_engine.paddle.recompute

transformer_engine/paddle/fp8.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
from .constants import dist_group_type
1616
from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer
1717

18+
19+
__all__ = ['fp8_autocast']
20+
21+
1822
# FP8 support
1923
_is_fp8_available = None
2024
_reason_for_no_fp8 = ""
@@ -166,6 +170,40 @@ def fp8_autocast(
166170
) -> None:
167171
"""
168172
Context manager for FP8 usage.
173+
174+
.. code-block:: python
175+
176+
with fp8_autocast(enabled=True):
177+
out = model(inp)
178+
179+
.. note::
180+
181+
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
182+
with shapes where both dimensions are divisible by 16. In terms of the input to the full
183+
Transformer network, this typically requires padding sequence length to be multiple of 16.
184+
185+
.. note::
186+
187+
When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once
188+
inside a single `fp8_autocast` region. This is unsupported behavior because the amax
189+
reduction is handled during the exit of the `fp8_autocast` context. Calling the same
190+
module more than once inside an `fp8_autocast` region overrides the amax tensors
191+
before reduction can occur.
192+
193+
Parameters
194+
----------
195+
enabled: bool, default = `False`
196+
whether or not to enable fp8
197+
calibrating: bool, default = `False`
198+
calibration mode allows collecting statistics such as amax and scale
199+
data of fp8 tensors even when executing without fp8 enabled. This is
200+
useful for saving an inference ready fp8 checkpoint while training
201+
using a higher precision.
202+
fp8_recipe: recipe.DelayedScaling, default = `None`
203+
recipe used for FP8 training.
204+
fp8_group: paddle.distributed.collective.Group, default = `None`
205+
distributed group over which amaxes for the fp8 tensors
206+
are reduced at the end of each training step.
169207
"""
170208
try:
171209
_global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group)

transformer_engine/paddle/layer/attention.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from ..recompute import recompute
3030

3131

32+
__all__ = ["DotProductAttention", "MultiHeadAttention"]
33+
34+
3235
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
3336
"""Function for FusedAttention with packed QKV input"""
3437

@@ -129,7 +132,7 @@ def backward(ctx, d_out):
129132

130133

131134
class DotProductAttention(paddle.nn.Layer):
132-
"""Dot Product Attention Layer
135+
"""
133136
Allows the model to jointly attend to information from different
134137
representation subspaces as described in the paper:
135138
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
@@ -150,8 +153,7 @@ class DotProductAttention(paddle.nn.Layer):
150153
attention_type: {'self', 'cross'}, default = `self`
151154
type of attention operation.
152155
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
153-
backend to use for attention operation.
154-
156+
backend to use for attention operation.
155157
"""
156158

157159
def __init__(self,
@@ -215,17 +217,17 @@ def forward(
215217
Parameters
216218
----------
217219
query_layer : paddle.Tensor
218-
Query tensor.
220+
Query tensor.
219221
key_value_layer : paddle.Tensor
220-
Key tensor.
222+
Key tensor.
221223
attention_mask : Optional[paddle.Tensor], default = `None`
222-
Boolean tensor used to mask out softmax input when not using attention.
224+
Boolean tensor used to mask out softmax input when not using attention.
223225
core_attention_bias_type: str, default = `no_bias`
224-
only support no_bias type currently, {`no_bias`}
226+
only support no_bias type currently, {`no_bias`}
225227
core_attention_bias: Optional[paddle.Tensor], default = `None`
226-
Bias tensor for Q * K.T
227-
set_zero: bool, defautl = `True`
228-
Whether to use the fast path to set output tensors to 0 or not.
228+
Bias tensor for Q * K.T
229+
set_zero: bool, default = `True`
230+
Whether to use the fast path to set output tensors to 0 or not.
229231
"""
230232

231233
backend = self.backend
@@ -358,7 +360,9 @@ def _pd_forward(
358360

359361

360362
class MultiHeadAttention(paddle.nn.Layer):
361-
"""Attention w/ QKV and Proj Gemms
363+
"""
364+
Multi-head Attention (MHA), including Query,
365+
Key, Value and Output projection.
362366
363367
Parameters
364368
----------
@@ -387,7 +391,8 @@ class MultiHeadAttention(paddle.nn.Layer):
387391
zero_centered_gamma: bool, default = `False`
388392
whether to zero initialize the gamma of the layernorm operation.
389393
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
390-
backend to use for attention operation.
394+
backend to use for attention operation. If set to 'paddle', a framework
395+
only no-FP8 path is executed with limited optimization.
391396
392397
Parallelism parameters
393398
----------------------
@@ -542,7 +547,6 @@ def forward(
542547
"""
543548
MultiHeadAttention Layer.
544549
545-
546550
Parameters
547551
----------
548552
hidden_states : paddle.Tensor
@@ -555,7 +559,7 @@ def forward(
555559
only support no_bias type currently, {`no_bias`}
556560
core_attention_bias: Optional[paddle.Tensor], default = `None`
557561
Bias tensor for Q * K.T
558-
set_zero: bool, defautl = `True`
562+
set_zero: bool, default = `True`
559563
Whether to use the fast path to set output tensors to 0 or not.
560564
recompute_core_attention: bool, default = `False`
561565
If true, forward activations for core attention are recomputed

transformer_engine/paddle/layer/layernorm.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,33 @@ def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None
6363
class LayerNorm(paddle.nn.Layer):
6464
r"""
6565
Applies Layer Normalization over a mini-batch of inputs as described in
66-
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`
66+
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
67+
68+
.. math::
69+
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
70+
71+
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
72+
size :attr:`hidden_size`
73+
74+
Parameters
75+
----------
76+
hidden_size : int
77+
size of each input sample.
78+
eps : float, default = 1e-5
79+
a value added to the denominator of layer normalization for numerical stability.
80+
weight_attr: Union[paddle.ParamAttr, None], default = None
81+
optional `paddle.ParamAttr` for weight.
82+
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
83+
optional `paddle.ParamAttr` for bias.
84+
zero_centered_gamma : bool, default = 'False'
85+
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
86+
the LayerNorm formula changes to
87+
88+
.. math::
89+
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
90+
(1 + \gamma) + \beta
91+
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
92+
backend to use for layer normalization operation.
6793
"""
6894

6995
def __init__(

transformer_engine/paddle/layer/layernorm_linear.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
saved_tensor_allow_none,
4141
)
4242

43-
__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"]
43+
__all__ = ["LayerNormLinear"]
4444

4545

4646
def _layernorm_fwd_fp8_cast(
@@ -331,6 +331,42 @@ def backward(
331331
class LayerNormLinear(TransformerEngineBaseLayer):
332332
r"""
333333
Applies layer normalization followed by linear transformation to the incoming data.
334+
335+
Parameters
336+
----------
337+
in_features : int
338+
size of each input sample.
339+
out_features : int
340+
size of each output sample.
341+
eps : float, default = 1e-5
342+
a value added to the denominator of layer normalization for numerical stability.
343+
weight_attr: Union[paddle.ParamAttr, None], default = None
344+
optional `paddle.ParamAttr` for weight.
345+
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
346+
optional `paddle.ParamAttr` for bias.
347+
return_layernorm_output : bool, default = `False`
348+
if set to `True`, output of layernorm is returned from the forward
349+
together with the output of the linear transformation.
350+
Example use case: residual connection for transformer module is
351+
taken post layernorm.
352+
zero_centered_gamma : bool, default = 'False'
353+
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
354+
the LayerNorm formula changes to
355+
356+
.. math::
357+
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
358+
(1 + \gamma) + \beta
359+
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
360+
if set to 'paddle', a framework only no-FP8 path is executed with limited optimization.
361+
362+
Parallelism parameters
363+
----------------------
364+
tp_group : ProcessGroup, default = `None`
365+
tensor parallel process group.
366+
parallel_mode : {None, 'Column', 'Row'}, default = `None`
367+
used to decide whether this Linear layer is Column Parallel Linear or Row
368+
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
369+
When set to `None`, no communication is performed.
334370
"""
335371

336372
def __init__(
@@ -503,7 +539,14 @@ def _pd_forward(
503539
return out
504540

505541
def forward(self, *args, **kwargs):
506-
"""forward"""
542+
"""
543+
Apply layer normalization to the input followed by a linear transformation.
544+
545+
Parameters
546+
----------
547+
inp : paddle.Tensor
548+
Input tensor.
549+
"""
507550
if self.backend == 'transformer_engine':
508551
return self._te_forward(*args, **kwargs)
509552
if self.backend == 'paddle':

transformer_engine/paddle/layer/layernorm_mlp.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
saved_tensor_allow_none,
4040
)
4141

42+
4243
__all__ = ["LayerNormMLP"]
4344

4445

@@ -549,7 +550,47 @@ def backward(
549550

550551
class LayerNormMLP(TransformerEngineBaseLayer):
551552
r"""
552-
Applies layer normalization followed by linear transformation to the incoming data.
553+
Applies layer normalization on the input followed by the MLP module, consisting of
554+
2 successive linear transformations, separated by the GeLU activation.
555+
556+
Parameters
557+
----------
558+
hidden_size : int
559+
size of each input sample.
560+
ffn_hidden_size : int
561+
intermediate size to which input samples are projected.
562+
eps : float, default = 1e-5
563+
a value added to the denominator of layer normalization for numerical stability.
564+
weight_attr: Union[paddle.ParamAttr, None], default = None
565+
optional `paddle.ParamAttr` for weight.
566+
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
567+
optional `paddle.ParamAttr` for bias.
568+
activation : str, default = 'gelu'
569+
activation function used.
570+
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
571+
return_layernorm_output : bool, default = `False`
572+
if set to `True`, output of layernorm is returned from the forward
573+
together with the output of the linear transformation.
574+
Example use case: residual connection for transformer module
575+
is taken post layernorm.
576+
zero_centered_gamma : bool, default = 'False'
577+
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
578+
the LayerNorm formula changes to
579+
580+
.. math::
581+
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
582+
(1 + \gamma) + \beta
583+
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
584+
if set to 'paddle', a framework only no-FP8 path is executed with limited optimization.
585+
586+
Parallelism parameters
587+
----------------------
588+
set_parallel_mode : bool, default = `False`
589+
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
590+
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
591+
tp_group : paddle.distributed.collective.Group, default = `None`
592+
tensor parallel process group.
593+
553594
"""
554595

555596
def __init__(
@@ -753,7 +794,14 @@ def _pd_forward(
753794
return out
754795

755796
def forward(self, *args, **kwargs):
756-
"""forward"""
797+
"""
798+
Apply layer normalization to the input followed by a feedforward network (MLP Block).
799+
800+
Parameters
801+
----------
802+
inp : paddle.Tensor
803+
Input tensor.
804+
"""
757805
if self.backend == 'transformer_engine':
758806
return self._te_forward(*args, **kwargs)
759807
if self.backend == 'paddle':

0 commit comments

Comments
 (0)