@@ -799,23 +799,18 @@ def __init__(
     def forward(self, x1, x2, v, _hyena_use_cp=True):
         """Shape specification for inputs and outputs.
 
-        Input shapes: bs, seq_length, (num_groups, group_size)
-        Output shapes: bs, seq_length, num_groups, group_size
+        Input shapes: bs, (num_groups, group_size), seq_length
+        Output shapes: bs, (num_groups, group_size), seq_length
         """
-        B, L, G, DG = x1.shape
+        B, GDG, L = x1.shape
+        x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
 
         # CP control
         if _hyena_use_cp:
             cp_group = get_context_parallel_group()
         else:
             cp_group = None
 
-        x1 = rearrange(x1, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim)
-        x2 = rearrange(x2, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim)
-        v = rearrange(v, "b l g dg -> b (g dg) l", g=self.num_groups, dg=self.group_dim)
-
-        x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
-
         # The kernel length must be adjusted in CP settings
         _L_kernel = L if cp_group is None else L * len(torch.distributed.get_process_group_ranks(cp_group))
         if self.use_medium_hyena:
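The context lines above keep the existing context-parallel kernel-length adjustment unchanged. As a rough illustration only (the local length and world size below are arbitrary assumptions, not values from this diff), the adjustment works out as:

# Hedged illustration of the unchanged CP kernel-length logic above: each rank sees
# L local tokens of a sequence spanning L * cp_world_size tokens across the
# context-parallel group, so the filter length is scaled up accordingly.
L = 2048           # local sequence length on this rank (assumed value)
cp_world_size = 4  # len(torch.distributed.get_process_group_ranks(cp_group)), assumed
_L_kernel = L if cp_world_size == 1 else L * cp_world_size
print(_L_kernel)   # 8192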
@@ -869,7 +864,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True):
         if cp_group is not None and len(torch.distributed.get_process_group_ranks(cp_group)) > 1:
             z = AllToAllSingleFunction.apply(z, cp_group, "full_to_split", True)
             # [B, H, L // num_ranks]
-        return rearrange(z, "b d l -> b l d")
+        return z  # [B, (G, DG), L]
 
     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
         """Sharded state dictionary for the ParallelHyenaOperator."""
@@ -972,15 +967,10 @@ def __init__(
     def forward(self, x1, x2, v, _hyena_use_cp=True):
         """Shape specification for inputs and outputs.
 
-        Input shapes: bs, seq_length, (num_groups, group_size)
-        Output shapes: bs, seq_length, num_groups, group_size
+        Input shapes: bs, (num_groups, group_size), seq_length
+        Output shapes: bs, (num_groups, group_size), seq_length
         """
-        B, L, G, DG = x1.shape
-
-        x1 = rearrange(x1, "b l g dg -> b (g dg) l")
-        x2 = rearrange(x2, "b l g dg -> b (g dg) l")
-        v = rearrange(v, "b l g dg -> b (g dg) l")
-
+        B, GDG, L = x1.shape
 
         x1, x2, v = x1[..., :L], x2[..., :L], v[..., :L]
 
         z = x2 * v if self.pregate else v
@@ -993,7 +983,7 @@ def forward(self, x1, x2, v, _hyena_use_cp=True):
 
         z = x1 * z if self.postgate else z
 
-        return rearrange(z, "b d l -> b l d")
+        return z  # [B, (G, DG), L]
 
     def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
         """Sharded state dictionary for the ParallelShortHyenaOperator."""