4444
4545__all__ = ["GroupedLinear" ]
4646
47- """
48- The offset for fp8_meta_index.
49- _GEMM_INPUT = 0
50- _GEMM_WEIGHT = num_gemms
51- _GEMM_OUTPUT = 2 * num_gemms
52- Must be properly set in GroupedLinear's initialization.
53- """
54- _GEMM_INPUT = 0
55- _GEMM_WEIGHT = 0
56- _GEMM_OUTPUT = 0
57- _GRAD_OUTPUT = 0
58-
5947
6048class _GroupedLinear (torch .autograd .Function ):
6149 """GroupedLinear semi-top level module
@@ -74,12 +62,9 @@ def forward(
7462 fp8_meta : Dict [str , Any ],
7563 fuse_wgrad_accumulation : bool ,
7664 cpu_offloading : bool ,
77- tp_group : Union [dist_group_type , None ],
78- tp_size : int ,
7965 sequence_parallel : bool ,
80- tensor_parallel : bool ,
8166 activation_dtype : torch .dtype ,
82- parallel_mode : Union [str , None ],
67+ fp8_meta_offsets : Dict [str , int ],
8368 is_grad_enabled : bool ,
8469 weights_fp8 : List [Union [Float8Tensor , None ]],
8570 * weights_and_biases : Union [Float8Tensor , torch .Tensor , None ],
@@ -103,7 +88,6 @@ def forward(
10388 inputmats_t = []
10489 inputmat_scale_inv = None
10590
106- global _GEMM_INPUT , _GEMM_WEIGHT , _GEMM_OUTPUT
10791 if fp8 :
10892 fp8_dtype_forward = get_fp8_te_dtype (fp8_meta ["recipe" ], fprop_tensor = True )
10993 inputmat_scale_inv = torch .empty ([num_gemms ], dtype = torch .float32 , device = inp .device )
@@ -114,7 +98,9 @@ def forward(
11498 and not sequence_parallel
11599 ):
116100 # FP8 input for forward, FP8 input transpose for backward wgrad
117- indices = list (range (_GEMM_INPUT , _GEMM_INPUT + num_gemms ))
101+ indices = list (
102+ range (fp8_meta_offsets ["input" ], fp8_meta_offsets ["input" ] + num_gemms )
103+ )
118104 inputmats , inputmats_t = fp8_multi_cast_transpose_fused (
119105 inputmats_no_fp8 ,
120106 fp8_meta ["scaling_fwd" ],
@@ -130,7 +116,7 @@ def forward(
130116 cast_to_fp8 (
131117 inputmats_no_fp8 [i ],
132118 fp8_meta ["scaling_fwd" ],
133- _GEMM_INPUT + i ,
119+ fp8_meta_offsets [ "input" ] + i ,
134120 fp8_dtype_forward ,
135121 scale_inv = inputmat_scale_inv ,
136122 )
@@ -194,14 +180,14 @@ def forward(
194180 for i in range (num_gemms ):
195181 # amax of input
196182 amin , amax = inputmats [i ].aminmax ()
197- fp8_meta ["scaling_fwd" ].amax_history [0 ][_GEMM_INPUT + i ] = torch . max (
198- - amin , amax
199- ). float ()
183+ fp8_meta ["scaling_fwd" ].amax_history [0 ][fp8_meta_offsets [ "input" ] + i ] = (
184+ torch . max ( - amin , amax ). float ()
185+ )
200186 # amax of weight
201187 amin , amax = weights [i ].aminmax ()
202- fp8_meta ["scaling_fwd" ].amax_history [0 ][_GEMM_WEIGHT + i ] = torch . max (
203- - amin , amax
204- ). float ()
188+ fp8_meta ["scaling_fwd" ].amax_history [0 ][fp8_meta_offsets [ "weight" ] + i ] = (
189+ torch . max ( - amin , amax ). float ()
190+ )
205191
206192 out = torch .empty (
207193 [sum (m_splits ), weights [0 ].size (0 )],
@@ -266,11 +252,8 @@ def forward(
266252 ctx .is_first_microbatch = is_first_microbatch
267253 ctx .use_bias = use_bias
268254 ctx .sequence_parallel = sequence_parallel
269- ctx .tensor_parallel = tensor_parallel
270255 ctx .inp_shape = inp .shape
271- ctx .parallel_mode = parallel_mode
272- ctx .tp_group = tp_group
273- ctx .tp_size = tp_size
256+ ctx .fp8_meta_offsets = fp8_meta_offsets
274257 ctx .requires_dgrad = inp .requires_grad
275258 ctx .reduce_and_update_bwd_fp8_tensors = False
276259 if ctx .fp8 and requires_grad (inp , weights [0 ], biases [0 ]):
@@ -300,7 +283,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
300283 w .main_grad = main_grads [i ]
301284 weights [i ] = w
302285
303- global _GEMM_INPUT , _GEMM_WEIGHT , _GRAD_OUTPUT
304286 # preprocess grad_output
305287 grad_output = grad_output .contiguous ()
306288 grad_output_mats = torch .split (
@@ -318,13 +300,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
318300 fp8_cast_transpose_bgrad_fused (
319301 grad_output_mats [i ],
320302 ctx .fp8_meta ["scaling_bwd" ],
321- _GRAD_OUTPUT + i ,
303+ ctx . fp8_meta_offsets [ "grad_output" ] + i ,
322304 fp8_dtype_backward ,
323305 )
324306 )
325307 else :
326308 if not ctx .fp8_meta ["recipe" ].override_linear_precision .wgrad :
327- indices = list (range (_GRAD_OUTPUT , _GRAD_OUTPUT + ctx .num_gemms ))
309+ indices = list (
310+ range (
311+ ctx .fp8_meta_offsets ["grad_output" ],
312+ ctx .fp8_meta_offsets ["grad_output" ] + ctx .num_gemms ,
313+ )
314+ )
328315 grad_output_c , grad_output_t = fp8_multi_cast_transpose_fused (
329316 grad_output_mats ,
330317 ctx .fp8_meta ["scaling_bwd" ],
@@ -338,7 +325,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
338325 grad_output_c [i ] = cast_to_fp8 (
339326 grad_output_mats [i ],
340327 ctx .fp8_meta ["scaling_bwd" ],
341- _GRAD_OUTPUT + i ,
328+ ctx . fp8_meta_offsets [ "grad_output" ] + i ,
342329 fp8_dtype_backward ,
343330 )
344331
@@ -363,7 +350,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
363350 weights_fp8 [0 ]._fp8_dtype ,
364351 grad_output_c ,
365352 ctx .fp8_meta ["scaling_bwd" ].scale_inv ,
366- _GRAD_OUTPUT ,
353+ ctx . fp8_meta_offsets [ "grad_output" ] ,
367354 fp8_dtype_backward ,
368355 [dgrad ],
369356 ctx .activation_dtype ,
@@ -416,7 +403,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
416403 fp8_dtype_forward ,
417404 grad_output_t ,
418405 ctx .fp8_meta ["scaling_bwd" ].scale_inv ,
419- _GRAD_OUTPUT ,
406+ ctx . fp8_meta_offsets [ "grad_output" ] ,
420407 fp8_dtype_backward ,
421408 wgrad_list ,
422409 ctx .activation_dtype ,
@@ -497,12 +484,9 @@ def handle_custom_ddp_from_mcore(w, wgrad):
497484 None , # fp8_meta
498485 None , # fuse_wgrad_accumulation
499486 None , # cpu_offloading
500- None , # tp_group
501- None , # tp_size
502487 None , # sequence_parallel
503- None , # tensor_parallel
504488 None , # activation_dtype
505- None , # parallel_mode
489+ None , # fp8_meta_offsets
506490 None , # is_grad_enabled
507491 None , # weights_fp8
508492 * wgrad_list ,
@@ -536,23 +520,6 @@ class GroupedLinear(TransformerEngineBaseModule):
536520 responsibility to ensure all parameters are moved to the GPU before running the
537521 forward pass.
538522
539- Parallelism parameters
540- ----------------------
541- sequence_parallel : bool, default = `False`
542- if set to `True`, uses sequence parallelism.
543- tp_group : ProcessGroup, default = `None`
544- tensor parallel process group.
545- tp_size : int, default = 1
546- used as TP (tensor parallel) world size when TP groups are not formed during
547- initialization. In this case, users must call the
548- `set_tensor_parallel_group(tp_group)` method on the initialized module before the
549- forward pass to supply the tensor parallel group needed for tensor and sequence
550- parallel collectives.
551- parallel_mode : {None, 'column', 'row'}, default = `None`
552- used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
553- Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
554- When set to `None`, no communication is performed.
555-
556523 Optimization parameters
557524 -----------------------
558525 fuse_wgrad_accumulation : bool, default = 'False'
@@ -613,8 +580,7 @@ def __init__(
613580 self .get_rng_state_tracker = get_rng_state_tracker
614581 self .rng_tracker_name = rng_tracker_name
615582
616- global _GEMM_INPUT , _GEMM_WEIGHT , _GEMM_OUTPUT
617- _GEMM_INPUT , _GEMM_WEIGHT , _GEMM_OUTPUT = 0 , num_gemms , 2 * num_gemms
583+ self ._offsets = {"input" : 0 , "weight" : num_gemms , "output" : 2 * num_gemms , "grad_output" : 0 }
618584
619585 if tp_group is None :
620586 self .tp_size = tp_size
@@ -651,7 +617,7 @@ def __init__(
651617 ),
652618 init_fn = init_method ,
653619 get_rng_state_tracker = get_rng_state_tracker ,
654- fp8_meta_index = _GEMM_WEIGHT + i ,
620+ fp8_meta_index = self . _offsets [ "weight" ] + i ,
655621 )
656622
657623 # Construct bias parameters if needed
@@ -774,7 +740,7 @@ def forward(
774740 weight_tensors_fp8 [i ] = self .get_fp8_workspace (
775741 tensor = weight_tensors [i ],
776742 fp8_meta_forward = True ,
777- fp8_meta_index = _GEMM_WEIGHT + i ,
743+ fp8_meta_index = self . _offsets [ "weight" ] + i ,
778744 cache_name = (None if is_first_microbatch is None else f"weight{ i } " ),
779745 update_workspace = update_workspace ,
780746 skip_update_flag = skip_fp8_weight_update ,
@@ -798,12 +764,9 @@ def forward(
798764 self .fp8_meta ,
799765 self .fuse_wgrad_accumulation ,
800766 CPUOffloadEnabled ,
801- self .tp_group ,
802- self .tp_size ,
803767 self .sequence_parallel ,
804- self .tp_size > 1 ,
805768 self .activation_dtype ,
806- self .parallel_mode ,
769+ self ._offsets ,
807770 torch .is_grad_enabled (),
808771 weight_tensors_fp8 ,
809772 * weight_tensors ,
0 commit comments