
Commit a5d47de

farhadrgh authored and jomitchellnv committed
Remove Unexecuted Hyena Code Paths (NVIDIA#12856)
* drop multihead_forward  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop multihead_forward  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop flashfft  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop use_long_conv1d  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop cgcg  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop custom hyena mlp/short conv kernels  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop downsampling  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop unused is_mlp path  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* improve make_upper_case  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* remove unused configs  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* switch to inplace  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* consolidate fp8 logic  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* doc str  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* rm douplicate import  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* fix linting issues  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* drop unnecessary rearrange  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* reduce rearrange overhead  Signed-off-by: Farhad Ramezanghorbani <[email protected]>
* Apply isort and black reformatting  Signed-off-by: farhadrgh <[email protected]>
* make flake8 happy  Signed-off-by: Farhad Ramezanghorbani <[email protected]>

---------

Signed-off-by: Farhad Ramezanghorbani <[email protected]>
Signed-off-by: farhadrgh <[email protected]>
Signed-off-by: Jonathan Mitchell <[email protected]>
1 parent c9fea9f commit a5d47de

File tree

3 files changed (+128, -501 lines)


Diff for: nemo/collections/llm/gpt/model/megatron/hyena/hyena_config.py

-214 lines
@@ -40,135 +40,6 @@ class HyenaConfig:
     # Weight to apply to lowercase tokens in the loss function, 1.0 is no reweighting.
     # """

-    use_flashfft: bool = False
-    """
-    Use flashfftconv instead of torch fft kernel (requires installation of flashfftconv)for hyena
-    """
-
-    use_cgcg: bool = False
-    """
-    Use cgcg (chunked gate-conv-gate) kernel for hyena
-    """
-
-    use_cgcg_short: bool = False
-    """
-    Use cgcg (chunked gate-conv-gate) kernel for hyena short conv
-    """
-
-    use_cgcg_mlp: bool = False
-    """
-    Use cgcg (chunked gate-conv-gate) kernel for hyena mlp
-    """
-
-    cgcg_dtype: str = "bfloat16"
-    """
-    dtype to use within cgcg kernel
-    """
-    #
-    # cgcg_fwd_autotune: bool = False
-    # """
-    # Whether to autotune cgcg fwd kernel
-    #
-    # @jeromeku: Note autotuning fwd kernel is unstable,
-    # use pre-tuned config for now.
-    # """
-
-    cgcg_medium_fwd_kernel_config_chunk_size: int = 128
-    """
-    cgcg fwd medium conv kernel config chunk size
-    """
-    cgcg_medium_fwd_kernel_config_block_d: int = 128
-    """
-    cgcg fwd medium conv kernel config block d tile size
-    """
-
-    cgcg_medium_fwd_kernel_config_threadblock_swizzle: str = "row"
-    """
-    cgcg fwd medium conv kernel config threadblock swizzle type
-    """
-    cgcg_medium_fwd_kernel_config_chunk_tiles_per_program: int = 3
-    """
-    cgcg fwd medium conv kernel config chunk tiles per program
-    """
-
-    cgcg_medium_fwd_kernel_config_num_warps: int = 4
-    """
-    cgcg fwd short conv kernel config num warps
-    """
-
-    cgcg_medium_fwd_kernel_config_num_stages: int = 3
-    """
-    cgcg fwd medium conv kernel config num mma pipeline stages
-    """
-
-    cgcg_short_fwd_kernel_config_chunk_size: int = 128
-    """
-    cgcg fwd short conv kernel config chunk size
-    """
-    cgcg_short_fwd_kernel_config_block_d: int = 128
-    """
-    cgcg fwd short conv kernel config block d tile size
-    """
-
-    cgcg_short_fwd_kernel_config_threadblock_swizzle: str = "row"
-    """
-    cgcg fwd short conv kernel config threadblock swizzle type
-    """
-    cgcg_short_fwd_kernel_config_chunk_tiles_per_program: int = 1
-    """
-    cgcg fwd short conv kernel config chunk tiles per program
-    """
-
-    cgcg_short_fwd_kernel_config_num_warps: int = 4
-    """
-    cgcg fwd short conv kernel config num warps
-    """
-
-    cgcg_short_fwd_kernel_config_num_stages: int = 1
-    """
-    cgcg fwd short conv kernel config num mma pipeline stages
-    """
-
-    cgcg_bwd_autotune: bool = True
-    """
-    Whether to autotune cgcg bwd kernel
-    """
-
-    cgcg_fused_bwd: bool = True
-    """
-    Whether to use fused cgcg bwd kernel
-    """
-
-    cgcg_bwd_kernel_config_pre_conv_block_x: int = 128
-    """
-    cgcg bwd pre_conv kernel config block x tile size
-    """
-
-    cgcg_bwd_kernel_config_pre_conv_block_y: int = 128
-    """
-    cgcg bwd pre_conv kernel config block y tile size
-    """
-
-    cgcg_bwd_kernel_config_pre_conv_num_warps: int = 8
-    """
-    cgcg bwd pre_conv kernel config num warps
-    """
-
-    cgcg_bwd_kernel_config_post_conv_block_x: int = 32
-    """
-    cgcg bwd post conv kernel config block x tile size
-    """
-
-    cgcg_bwd_kernel_config_post_conv_block_y: int = 128
-    """
-    cgcg bwd post conv kernel config block y tile size
-    """
-
-    cgcg_bwd_kernel_config_post_conv_num_warps: int = 4
-    """
-    cgcg bwd post conv kernel config num warps
-    """
-
     short_conv_L: int = 3
     """
     For Hyena models, length of the short convolution.
@@ -191,13 +62,6 @@ class HyenaConfig:
     Use external fast heads in Hyena mixer (reduce BEFORE fftconv)
     """

-    use_slow_heads: bool = False
-    """
-    Use external outer-product heads in Hyena.
-    """
-
-    use_long_conv1d: bool = False
-
     num_groups_hyena: int = None
     """
     Determines number of unique filters to have, for the hyena long filter.
@@ -213,11 +77,6 @@ class HyenaConfig:
     Determines number of unique filters to have, for the hyena short filter.
     """

-    num_groups_hyena_mlp: int = None  # TODO: Possibly remove, only used if is_mlp is True
-    """
-    Determines number of unique filters to have, for the hyena mlp (filter).
-    """
-
     use_depthwise_short_conv_grouping: bool = True
     """
     Whether to use depthwise convolution grouping for short conv and hyena mlp filters.
@@ -237,48 +96,14 @@ class HyenaConfig:
     For medium hyena filters specifically, None defaults ot same as hyena_filter_cls (long filters).
     """

-    hyena_filter_r_max: float = 0.99  # TODO: Possibly remove, only used in ParallelComplexModalFilter
-
-    hyena_filter_r_min: float = 0.5  # TODO: Possibly remove, only used in ParallelComplexModalFilter
-
-    hyena_filter_emb_dim: int = 33  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
-    hyena_filter_fast_decay: float = 0.3  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
-    hyena_filter_slow_decay: float = 1.2  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
     hyena_filter_order: int = 16

-    hyena_filter_num_inner_mlps: int = 2  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
-    hyena_filter_w: int = 14  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
-    hyena_filter_wd: float = 0.0  # TODO: Where to override WD value for filters?
-
-    hyena_filter_omega_0: float = 1  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
-    hyena_pos_emb: str = "fourier_fixed"  # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
     explicit_filter_decay_preset: str = "weak"

-    modal_residue_factors: int = 3  # TODO: Possibly remove, only used in ImplicitRealModelFilter
-
-    modal_pole_factors: int = 3  # TODO: Possibly remove, only used in ImplicitRealModelFilter
-
     modal_gamma_min: float = 0.01

     modal_gamma_max: float = 0.1

-    use_custom_hyena_short_kernel: bool = False
-    """
-    Use a custom causal conv layer for the hyena short conv layer.
-    """
-
-    use_custom_hyena_mlp_kernel: bool = False  # TODO: Possibly remove - only relevant if is_mlp is True
-    """
-    Use a custom causal conv layer for the hyena short conv layer.
-    """
-
     bidirectional: bool = False
     """
     A bidirectional version of hyena fftconv
@@ -304,31 +129,6 @@ class HyenaConfig:
     Use a custom causal conv layer for the hyena short conv layer.
     """

-    hyena_mlp_len: int = 7  # TODO: Possibly remove, only used if is_mlp is True
-    """
-    Length of filter used inside the hyena mlp layer. Defaults to hyena_short_conv_len if not provided.
-    """
-
-    fast_hyena_mlp_conv: bool = False  # TODO: Possibly remove, only used if is_mlp is True
-    """
-    Use a custom causal conv layer for the hyena MLP layer.
-    """
-
-    hyena_mlp_expansion_factor: float = 1.0  # TODO: Possibly remove, only used if is_mlp is True
-    """
-    Factor to expand the projections width within hyena MLP layers only.
-    """
-
-    hyena_mlp_pregate: bool = True  # TODO: Possibly remove, only used if is_mlp is True
-    """
-    Use a pre-gate in the hyena MLP layer.
-    """
-
-    hyena_mlp_postgate: bool = True  # TODO: Possibly remove, only used if is_mlp is True
-    """
-    Use a post-gate in the hyena MLP layer.
-    """
-
     hyena_short_conv_pregate: bool = True
     """
     Use a pre-gate in the hyena short conv layer.
@@ -342,17 +142,3 @@ class HyenaConfig:
     proj_groups: int = 1

     grouped_attention: bool = False
-
-    # mlp_type: str = "regular"  # TODO: In Savanna setting this to 'short_hyena' uses hyena for MLP (is_mlp == True)
-    # """
-    # Types:
-    #     regular: Megatron implementation
-    #     llama: LLaMA MLP (SiLU-gated MLP)
-    #     short_hyena
-    #     identity
-    # """
-    #
-    # make_gated_mlp_multiple_of: int = 16  # TODO: Use this or just have user calculate ffn_size themselves?
-    # """
-    # Set the ff_dim to be a multiple of this value for llama mlp. Useful for sharding / using model parallel properly.
-    # """

0 commit comments
