
Commit 2499da5

reduce rearrange overhead

Signed-off-by: Farhad Ramezanghorbani <[email protected]>
1 parent: 52cd3fa

File tree: 2 files changed (+20/-24 lines)

Diff for: nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py (+11/-5)

```diff
@@ -237,10 +237,16 @@ def _maybe_use_fp8(self, func, *args, **kwargs):
         return func(*args, **kwargs)
 
     def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True):
-        """Applies sequence mixing to a sequence of 1-dimensional embeddings: batch_size, seq_len, d_model.
+        """Applies the Hyena sequence mixing operation to input embeddings.
 
         Args:
-            u: input to the operator, in format [B, L, D]
+            x: Input tensor of shape [L, B, D] (seq_len, batch_size, hidden_dim)
+            layer_past: Past layer state for inference (default: None)
+            inference_params: Parameters for inference (default: None)
+            _hyena_use_cp: Whether to use context parallelism (default: True)
+
+        Returns:
+            Tuple of (output tensor, bias)
         """
         # CP control
         if _hyena_use_cp:
@@ -257,11 +263,11 @@ def forward(self, x, layer_past=None, inference_params=None, _hyena_use_cp=True)
         features = self.hyena_proj_conv(features, _use_cp=_proj_use_cp)  # [B, D, L]
 
         x1, x2, v = rearrange(
-            features, "b (g dg p) l -> b l g p dg", p=3, g=self.num_groups_per_tp_rank
-        ).unbind(dim=3)
+            features, "b (g dg p) l -> b (g dg) p l", p=3, g=self.num_groups_per_tp_rank
+        ).unbind(dim=2)
 
         z = self.mixer(x1, x2, v)
-        z = rearrange(z, "b l d -> l b d").contiguous()
+        z = rearrange(z, "b d l -> l b d").contiguous()
 
         y, bias = self.dense(z)
         return y, bias
```
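The savings come from einops semantics: the old pattern moves the length axis in between the decomposed channel axes, which forces a permuted, non-contiguous result (and the mixer output then needed its own transpose back), while the new pattern only regroups adjacent channel axes, which is a pure reshape. A minimal standalone sketch with toy sizes (not taken from the commit) that checks both the contiguity difference and the value equivalence of the two splits:

```python
import torch
from einops import rearrange

# Toy sizes chosen only for this demo.
B, L, g, dg, p = 2, 16, 4, 8, 3
features = torch.randn(B, g * dg * p, L)  # [B, D, L], as the projection conv emits

# Old pattern: the length axis lands between the decomposed channel axes,
# so the result is a permuted, non-contiguous view; any kernel that needs
# contiguous memory then pays for a full copy.
old = rearrange(features, "b (g dg p) l -> b l g p dg", p=p, g=g)
print(old.is_contiguous())  # False

# New pattern: (g dg p) -> (g dg) p keeps the original axis order, so this
# is a pure reshape - a zero-copy, contiguous view of the same storage.
new = rearrange(features, "b (g dg p) l -> b (g dg) p l", p=p, g=g)
print(new.is_contiguous())  # True

# Both splits carry the same values, just in different layouts.
old_x1 = old.unbind(dim=3)[0]  # [B, L, g, dg]
new_x1 = new.unbind(dim=2)[0]  # [B, g*dg, L]
assert torch.equal(rearrange(old_x1, "b l g dg -> b (g dg) l"), new_x1)
```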

Diff for: nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py (+9/-19)

```diff
@@ -799,23 +799,18 @@ def __init__(
     def forward(self, x1, x2, v, _hyena_use_cp=True):
         """Shape specification for inputs and outputs.
 
-        Input shapes: bs, seq_length, (num_groups, group_size)
-        Output shapes: bs, seq_length, num_groups, group_size
+        Input shapes: bs, (num_groups, group_size), seq_length
+        Output shapes: bs, (num_groups, group_size), seq_length
         """
-        B, L, G, DG = x1.shape
+        B, GDG, L = x1.shape
+        x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
 
         # CP control
         if _hyena_use_cp:
             cp_group = get_context_parallel_group()
         else:
             cp_group = None
 
-        x1 = rearrange(x1, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim)
-        x2 = rearrange(x2, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim)
-        v = rearrange(v, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim)
-
-        x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
-
         # The kernel length must be adjusted in CP settings
         _L_kernel = L if cp_group is None else L * len(torch.distributed.get_process_group_ranks(cp_group))
         if self.use_medium_hyena:
@@ -869,7 +864,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True):
         if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1:
             z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True)
             # [ B, H, L // num_ranks]
-        return rearrange(z, "b d l -> b l d")
+        return z  # [B, (G, DG), L]
 
     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
         """Sharded state dictionary for the ParallelHyenaOperator."""
@@ -972,15 +967,10 @@ def __init__(
     def forward(self, x1, x2, v, _hyena_use_cp=True):
         """Shape specification for inputs and outputs.
 
-        Input shapes: bs, seq_length, (num_groups, group_size)
-        Output shapes: bs, seq_length, num_groups, group_size
+        Input shapes: bs, (num_groups, group_size), seq_length
+        Output shapes: bs, (num_groups, group_size), seq_length
         """
-        B, L, G, DG = x1.shape
-
-        x1 = rearrange(x1, "b l g dg -> b (g dg) l")
-        x2 = rearrange(x2, "b l g dg -> b (g dg) l")
-        v = rearrange(v, "b l g dg -> b (g dg) l")
-
+        B, GDG, L = x1.shape
         x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
 
         z = x2 * v if self.pregate else v
@@ -993,7 +983,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True):
 
         z = x1 * z if self.postgate else z
 
-        return rearrange(z, "b d l -> b l d")
+        return z  # [B, (G, DG), L]
 
     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
         """Sharded state dictionary for the ParallelShortHyenaOperator."""
```
