Skip to content

Commit 4df8488

Browse files
Removed the unused options from GroupedLinear docs and fixed the bug with offsets (#1220)
* Removing the unused options from GroupedLinear docs and fixing the bug with offsets Signed-off-by: Przemyslaw Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * offsets -> fp8_meta_offsets Signed-off-by: Przemyslaw Tredak <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemyslaw Tredak <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 458c7de commit 4df8488

File tree

1 file changed

+27
-64
lines changed

1 file changed

+27
-64
lines changed

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 27 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -44,18 +44,6 @@
4444

4545
__all__ = ["GroupedLinear"]
4646

47-
"""
48-
The offset for fp8_meta_index.
49-
_GEMM_INPUT = 0
50-
_GEMM_WEIGHT = num_gemms
51-
_GEMM_OUTPUT = 2 * num_gemms
52-
Must be properly set in GroupedLinear's initialization.
53-
"""
54-
_GEMM_INPUT = 0
55-
_GEMM_WEIGHT = 0
56-
_GEMM_OUTPUT = 0
57-
_GRAD_OUTPUT = 0
58-
5947

6048
class _GroupedLinear(torch.autograd.Function):
6149
"""GroupedLinear semi-top level module
@@ -74,12 +62,9 @@ def forward(
7462
fp8_meta: Dict[str, Any],
7563
fuse_wgrad_accumulation: bool,
7664
cpu_offloading: bool,
77-
tp_group: Union[dist_group_type, None],
78-
tp_size: int,
7965
sequence_parallel: bool,
80-
tensor_parallel: bool,
8166
activation_dtype: torch.dtype,
82-
parallel_mode: Union[str, None],
67+
fp8_meta_offsets: Dict[str, int],
8368
is_grad_enabled: bool,
8469
weights_fp8: List[Union[Float8Tensor, None]],
8570
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
@@ -103,7 +88,6 @@ def forward(
10388
inputmats_t = []
10489
inputmat_scale_inv = None
10590

106-
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
10791
if fp8:
10892
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
10993
inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
@@ -114,7 +98,9 @@ def forward(
11498
and not sequence_parallel
11599
):
116100
# FP8 input for forward, FP8 input transpose for backward wgrad
117-
indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
101+
indices = list(
102+
range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms)
103+
)
118104
inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
119105
inputmats_no_fp8,
120106
fp8_meta["scaling_fwd"],
@@ -130,7 +116,7 @@ def forward(
130116
cast_to_fp8(
131117
inputmats_no_fp8[i],
132118
fp8_meta["scaling_fwd"],
133-
_GEMM_INPUT + i,
119+
fp8_meta_offsets["input"] + i,
134120
fp8_dtype_forward,
135121
scale_inv=inputmat_scale_inv,
136122
)
@@ -194,14 +180,14 @@ def forward(
194180
for i in range(num_gemms):
195181
# amax of input
196182
amin, amax = inputmats[i].aminmax()
197-
fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max(
198-
-amin, amax
199-
).float()
183+
fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = (
184+
torch.max(-amin, amax).float()
185+
)
200186
# amax of weight
201187
amin, amax = weights[i].aminmax()
202-
fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max(
203-
-amin, amax
204-
).float()
188+
fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = (
189+
torch.max(-amin, amax).float()
190+
)
205191

206192
out = torch.empty(
207193
[sum(m_splits), weights[0].size(0)],
@@ -266,11 +252,8 @@ def forward(
266252
ctx.is_first_microbatch = is_first_microbatch
267253
ctx.use_bias = use_bias
268254
ctx.sequence_parallel = sequence_parallel
269-
ctx.tensor_parallel = tensor_parallel
270255
ctx.inp_shape = inp.shape
271-
ctx.parallel_mode = parallel_mode
272-
ctx.tp_group = tp_group
273-
ctx.tp_size = tp_size
256+
ctx.fp8_meta_offsets = fp8_meta_offsets
274257
ctx.requires_dgrad = inp.requires_grad
275258
ctx.reduce_and_update_bwd_fp8_tensors = False
276259
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
@@ -300,7 +283,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
300283
w.main_grad = main_grads[i]
301284
weights[i] = w
302285

303-
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
304286
# preprocess grad_output
305287
grad_output = grad_output.contiguous()
306288
grad_output_mats = torch.split(
@@ -318,13 +300,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
318300
fp8_cast_transpose_bgrad_fused(
319301
grad_output_mats[i],
320302
ctx.fp8_meta["scaling_bwd"],
321-
_GRAD_OUTPUT + i,
303+
ctx.fp8_meta_offsets["grad_output"] + i,
322304
fp8_dtype_backward,
323305
)
324306
)
325307
else:
326308
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
327-
indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms))
309+
indices = list(
310+
range(
311+
ctx.fp8_meta_offsets["grad_output"],
312+
ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms,
313+
)
314+
)
328315
grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused(
329316
grad_output_mats,
330317
ctx.fp8_meta["scaling_bwd"],
@@ -338,7 +325,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
338325
grad_output_c[i] = cast_to_fp8(
339326
grad_output_mats[i],
340327
ctx.fp8_meta["scaling_bwd"],
341-
_GRAD_OUTPUT + i,
328+
ctx.fp8_meta_offsets["grad_output"] + i,
342329
fp8_dtype_backward,
343330
)
344331

@@ -363,7 +350,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
363350
weights_fp8[0]._fp8_dtype,
364351
grad_output_c,
365352
ctx.fp8_meta["scaling_bwd"].scale_inv,
366-
_GRAD_OUTPUT,
353+
ctx.fp8_meta_offsets["grad_output"],
367354
fp8_dtype_backward,
368355
[dgrad],
369356
ctx.activation_dtype,
@@ -416,7 +403,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
416403
fp8_dtype_forward,
417404
grad_output_t,
418405
ctx.fp8_meta["scaling_bwd"].scale_inv,
419-
_GRAD_OUTPUT,
406+
ctx.fp8_meta_offsets["grad_output"],
420407
fp8_dtype_backward,
421408
wgrad_list,
422409
ctx.activation_dtype,
@@ -497,12 +484,9 @@ def handle_custom_ddp_from_mcore(w, wgrad):
497484
None, # fp8_meta
498485
None, # fuse_wgrad_accumulation
499486
None, # cpu_offloading
500-
None, # tp_group
501-
None, # tp_size
502487
None, # sequence_parallel
503-
None, # tensor_parallel
504488
None, # activation_dtype
505-
None, # parallel_mode
489+
None, # fp8_meta_offsets
506490
None, # is_grad_enabled
507491
None, # weights_fp8
508492
*wgrad_list,
@@ -536,23 +520,6 @@ class GroupedLinear(TransformerEngineBaseModule):
536520
responsibility to ensure all parameters are moved to the GPU before running the
537521
forward pass.
538522
539-
Parallelism parameters
540-
----------------------
541-
sequence_parallel : bool, default = `False`
542-
if set to `True`, uses sequence parallelism.
543-
tp_group : ProcessGroup, default = `None`
544-
tensor parallel process group.
545-
tp_size : int, default = 1
546-
used as TP (tensor parallel) world size when TP groups are not formed during
547-
initialization. In this case, users must call the
548-
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
549-
forward pass to supply the tensor parallel group needed for tensor and sequence
550-
parallel collectives.
551-
parallel_mode : {None, 'column', 'row'}, default = `None`
552-
used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
553-
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
554-
When set to `None`, no communication is performed.
555-
556523
Optimization parameters
557524
-----------------------
558525
fuse_wgrad_accumulation : bool, default = 'False'
@@ -613,8 +580,7 @@ def __init__(
613580
self.get_rng_state_tracker = get_rng_state_tracker
614581
self.rng_tracker_name = rng_tracker_name
615582

616-
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
617-
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
583+
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
618584

619585
if tp_group is None:
620586
self.tp_size = tp_size
@@ -651,7 +617,7 @@ def __init__(
651617
),
652618
init_fn=init_method,
653619
get_rng_state_tracker=get_rng_state_tracker,
654-
fp8_meta_index=_GEMM_WEIGHT + i,
620+
fp8_meta_index=self._offsets["weight"] + i,
655621
)
656622

657623
# Construct bias parameters if needed
@@ -774,7 +740,7 @@ def forward(
774740
weight_tensors_fp8[i] = self.get_fp8_workspace(
775741
tensor=weight_tensors[i],
776742
fp8_meta_forward=True,
777-
fp8_meta_index=_GEMM_WEIGHT + i,
743+
fp8_meta_index=self._offsets["weight"] + i,
778744
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
779745
update_workspace=update_workspace,
780746
skip_update_flag=skip_fp8_weight_update,
@@ -798,12 +764,9 @@ def forward(
798764
self.fp8_meta,
799765
self.fuse_wgrad_accumulation,
800766
CPUOffloadEnabled,
801-
self.tp_group,
802-
self.tp_size,
803767
self.sequence_parallel,
804-
self.tp_size > 1,
805768
self.activation_dtype,
806-
self.parallel_mode,
769+
self._offsets,
807770
torch.is_grad_enabled(),
808771
weight_tensors_fp8,
809772
*weight_tensors,

0 commit comments

Comments (0)