@@ -40,135 +40,6 @@ class HyenaConfig:
# Weight to apply to lowercase tokens in the loss function, 1.0 is no reweighting.
# """

- use_flashfft : bool = False
- """
- Use flashfftconv instead of torch fft kernel (requires installation of flashfftconv)for hyena
- """
-
- use_cgcg : bool = False
- """
- Use cgcg (chunked gate-conv-gate) kernel for hyena
- """
-
- use_cgcg_short : bool = False
- """
- Use cgcg (chunked gate-conv-gate) kernel for hyena short conv
- """
-
- use_cgcg_mlp : bool = False
- """
- Use cgcg (chunked gate-conv-gate) kernel for hyena mlp
- """
-
- cgcg_dtype : str = "bfloat16"
- """
- dtype to use within cgcg kernel
- """
- #
- # cgcg_fwd_autotune: bool = False
- # """
- # Whether to autotune cgcg fwd kernel
- #
- # @jeromeku: Note autotuning fwd kernel is unstable,
- # use pre-tuned config for now.
- # """
-
- cgcg_medium_fwd_kernel_config_chunk_size : int = 128
- """
- cgcg fwd medium conv kernel config chunk size
- """
- cgcg_medium_fwd_kernel_config_block_d : int = 128
- """
- cgcg fwd medium conv kernel config block d tile size
- """
-
- cgcg_medium_fwd_kernel_config_threadblock_swizzle : str = "row"
- """
- cgcg fwd medium conv kernel config threadblock swizzle type
- """
- cgcg_medium_fwd_kernel_config_chunk_tiles_per_program : int = 3
- """
- cgcg fwd medium conv kernel config chunk tiles per program
- """
-
- cgcg_medium_fwd_kernel_config_num_warps : int = 4
- """
- cgcg fwd short conv kernel config num warps
- """
-
- cgcg_medium_fwd_kernel_config_num_stages : int = 3
- """
- cgcg fwd medium conv kernel config num mma pipeline stages
- """
-
- cgcg_short_fwd_kernel_config_chunk_size : int = 128
- """
- cgcg fwd short conv kernel config chunk size
- """
- cgcg_short_fwd_kernel_config_block_d : int = 128
- """
- cgcg fwd short conv kernel config block d tile size
- """
-
- cgcg_short_fwd_kernel_config_threadblock_swizzle : str = "row"
- """
- cgcg fwd short conv kernel config threadblock swizzle type
- """
- cgcg_short_fwd_kernel_config_chunk_tiles_per_program : int = 1
- """
- cgcg fwd short conv kernel config chunk tiles per program
- """
-
- cgcg_short_fwd_kernel_config_num_warps : int = 4
- """
- cgcg fwd short conv kernel config num warps
- """
-
- cgcg_short_fwd_kernel_config_num_stages : int = 1
- """
- cgcg fwd short conv kernel config num mma pipeline stages
- """
-
- cgcg_bwd_autotune : bool = True
- """
- Whether to autotune cgcg bwd kernel
- """
-
- cgcg_fused_bwd : bool = True
- """
- Whether to use fused cgcg bwd kernel
- """
-
- cgcg_bwd_kernel_config_pre_conv_block_x : int = 128
- """
- cgcg bwd pre_conv kernel config block x tile size
- """
-
- cgcg_bwd_kernel_config_pre_conv_block_y : int = 128
- """
- cgcg bwd pre_conv kernel config block y tile size
- """
-
- cgcg_bwd_kernel_config_pre_conv_num_warps : int = 8
- """
- cgcg bwd pre_conv kernel config num warps
- """
-
- cgcg_bwd_kernel_config_post_conv_block_x : int = 32
- """
- cgcg bwd post conv kernel config block x tile size
- """
-
- cgcg_bwd_kernel_config_post_conv_block_y : int = 128
- """
- cgcg bwd post conv kernel config block y tile size
- """
-
- cgcg_bwd_kernel_config_post_conv_num_warps : int = 4
- """
- cgcg bwd post conv kernel config num warps
- """
-
short_conv_L : int = 3
"""
For Hyena models, length of the short convolution.
@@ -191,13 +62,6 @@ class HyenaConfig:
Use external fast heads in Hyena mixer (reduce BEFORE fftconv)
"""

- use_slow_heads : bool = False
- """
- Use external outer-product heads in Hyena.
- """
-
- use_long_conv1d : bool = False
-
num_groups_hyena : int = None
"""
Determines number of unique filters to have, for the hyena long filter.
@@ -213,11 +77,6 @@ class HyenaConfig:
Determines number of unique filters to have, for the hyena short filter.
"""

- num_groups_hyena_mlp : int = None # TODO: Possibly remove, only used if is_mlp is True
- """
- Determines number of unique filters to have, for the hyena mlp (filter).
- """
-
use_depthwise_short_conv_grouping : bool = True
"""
Whether to use depthwise convolution grouping for short conv and hyena mlp filters.
@@ -237,48 +96,14 @@ class HyenaConfig:
For medium hyena filters specifically, None defaults ot same as hyena_filter_cls (long filters).
"""

- hyena_filter_r_max : float = 0.99 # TODO: Possibly remove, only used in ParallelComplexModalFilter
-
- hyena_filter_r_min : float = 0.5 # TODO: Possibly remove, only used in ParallelComplexModalFilter
-
- hyena_filter_emb_dim : int = 33 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
- hyena_filter_fast_decay : float = 0.3 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
- hyena_filter_slow_decay : float = 1.2 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
hyena_filter_order : int = 16

- hyena_filter_num_inner_mlps : int = 2 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
- hyena_filter_w : int = 14 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
- hyena_filter_wd : float = 0.0 # TODO: Where to override WD value for filters?
-
- hyena_filter_omega_0 : float = 1 # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
- hyena_pos_emb : str = "fourier_fixed" # TODO: Possibly remove, only used in ParallelImplicitFreeformFilter
-
explicit_filter_decay_preset : str = "weak"

- modal_residue_factors : int = 3 # TODO: Possibly remove, only used in ImplicitRealModelFilter
-
- modal_pole_factors : int = 3 # TODO: Possibly remove, only used in ImplicitRealModelFilter
-
modal_gamma_min : float = 0.01

modal_gamma_max : float = 0.1

- use_custom_hyena_short_kernel : bool = False
- """
- Use a custom causal conv layer for the hyena short conv layer.
- """
-
- use_custom_hyena_mlp_kernel : bool = False # TODO: Possibly remove - only relevant if is_mlp is True
- """
- Use a custom causal conv layer for the hyena short conv layer.
- """
-
bidirectional : bool = False
"""
A bidirectional version of hyena fftconv
@@ -304,31 +129,6 @@ class HyenaConfig:
Use a custom causal conv layer for the hyena short conv layer.
"""

- hyena_mlp_len : int = 7 # TODO: Possibly remove, only used if is_mlp is True
- """
- Length of filter used inside the hyena mlp layer. Defaults to hyena_short_conv_len if not provided.
- """
-
- fast_hyena_mlp_conv : bool = False # TODO: Possibly remove, only used if is_mlp is True
- """
- Use a custom causal conv layer for the hyena MLP layer.
- """
-
- hyena_mlp_expansion_factor : float = 1.0 # TODO: Possibly remove, only used if is_mlp is True
- """
- Factor to expand the projections width within hyena MLP layers only.
- """
-
- hyena_mlp_pregate : bool = True # TODO: Possibly remove, only used if is_mlp is True
- """
- Use a pre-gate in the hyena MLP layer.
- """
-
- hyena_mlp_postgate : bool = True # TODO: Possibly remove, only used if is_mlp is True
- """
- Use a post-gate in the hyena MLP layer.
- """
-
hyena_short_conv_pregate : bool = True
"""
Use a pre-gate in the hyena short conv layer.
@@ -342,17 +142,3 @@ class HyenaConfig:
proj_groups : int = 1

grouped_attention : bool = False
-
- # mlp_type: str = "regular" # TODO: In Savanna setting this to 'short_hyena' uses hyena for MLP (is_mlp == True)
- # """
- # Types:
- # regular: Megatron implementation
- # llama: LLaMA MLP (SiLU-gated MLP)
- # short_hyena
- # identity
- # """
- #
- # make_gated_mlp_multiple_of: int = 16 # TODO: Use this or just have user calculate ffn_size themselves?
- # """
- # Set the ff_dim to be a multiple of this value for llama mlp. Useful for sharding / using model parallel properly.
- # """
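For quick reference, below is a minimal sketch (not part of the commit) of the fields from this region of HyenaConfig that survive the removals. It is reconstructed only from the unchanged context lines in the hunks above; the real class defines many more fields outside this diff, the class name here is hypothetical, and typing num_groups_hyena as Optional[int] is an assumption based on its None default.

from dataclasses import dataclass
from typing import Optional


@dataclass
class HyenaConfigSketch:
    # Fields and defaults copied from the context lines in the diff above.
    short_conv_L: int = 3                         # length of the short convolution
    num_groups_hyena: Optional[int] = None        # number of unique hyena long filters
    use_depthwise_short_conv_grouping: bool = True
    hyena_filter_order: int = 16
    explicit_filter_decay_preset: str = "weak"
    modal_gamma_min: float = 0.01
    modal_gamma_max: float = 0.1
    bidirectional: bool = False                   # bidirectional variant of hyena fftconv
    hyena_short_conv_pregate: bool = True         # pre-gate in the hyena short conv layer
    proj_groups: int = 1
    grouped_attention: bool = False


if __name__ == "__main__":
    # Example instantiation; the num_groups_hyena value is illustrative, not from the diff.
    cfg = HyenaConfigSketch(num_groups_hyena=256)
    print(cfg)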