@@ -66,6 +66,16 @@ def replace(self, **kwargs):
6666 return dataclasses .replace (self , ** kwargs )
6767
6868
def maybe_with_partitioning(
    init_fn: nnx.Initializer,
    names: tuple[str | None, ...],
    config: TransformerConfig,
) -> nnx.Initializer:
  """Attach sharding metadata to ``init_fn`` unless no axis is named.

  Wraps ``init_fn`` with ``nnx.with_partitioning`` so the parameter it
  initializes carries logical axis names for sharding. When every entry of
  ``names`` is ``None`` (i.e. the config's axis rules request no
  partitioning), the initializer is returned unchanged so no partitioning
  metadata is created.

  Args:
    init_fn: The parameter initializer to (possibly) wrap.
    names: One logical axis name (or ``None``) per parameter dimension,
      typically produced by ``config.axis_rules(...)``.
    config: The transformer configuration. NOTE(review): currently unused —
      kept for call-site/interface stability; confirm whether it can be
      dropped or should gate the wrapping.

  Returns:
    ``init_fn`` itself when all names are ``None``; otherwise ``init_fn``
    wrapped with ``nnx.with_partitioning(init_fn, names)``.
  """
  del config  # Unused; accepted so all call sites share one signature.
  if all(name is None for name in names):
    return init_fn
  return nnx.with_partitioning(init_fn, names)
77+
78+
6979def shift_right (x : jax .Array , axis : int = 1 ):
7080 """Shift the input to the right by padding and slicing on axis."""
7181 pad_widths : list [tuple [int , int ]] = [(0 , 0 )] * len (x .shape )
@@ -202,32 +212,36 @@ def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
202212 self .config = config
203213
204214 self .linear1 = nnx .Linear (
205- config .emb_dim ,
206- config .mlp_dim ,
207- dtype = config .dtype ,
208- kernel_init = nnx .with_partitioning (
209- config .kernel_init ,
210- config .axis_rules ('embed' , 'mlp' ),
211- ),
212- bias_init = nnx .with_partitioning (
213- config .bias_init ,
214- config .axis_rules ('mlp' ),
215- ),
216- rngs = rngs ,
215+ config .emb_dim ,
216+ config .mlp_dim ,
217+ dtype = config .dtype ,
218+ kernel_init = maybe_with_partitioning (
219+ config .kernel_init ,
220+ config .axis_rules ('embed' , 'mlp' ),
221+ config ,
222+ ),
223+ bias_init = maybe_with_partitioning (
224+ config .bias_init ,
225+ config .axis_rules ('mlp' ),
226+ config ,
227+ ),
228+ rngs = rngs ,
217229 )
218230 self .linear2 = nnx .Linear (
219- config .mlp_dim ,
220- config .emb_dim ,
221- dtype = config .dtype ,
222- kernel_init = nnx .with_partitioning (
223- config .kernel_init ,
224- config .axis_rules ('mlp' , 'embed' ),
225- ),
226- bias_init = nnx .with_partitioning (
227- config .bias_init ,
228- config .axis_rules ('embed' ),
229- ),
230- rngs = rngs ,
231+ config .mlp_dim ,
232+ config .emb_dim ,
233+ dtype = config .dtype ,
234+ kernel_init = maybe_with_partitioning (
235+ config .kernel_init ,
236+ config .axis_rules ('mlp' , 'embed' ),
237+ config ,
238+ ),
239+ bias_init = maybe_with_partitioning (
240+ config .bias_init ,
241+ config .axis_rules ('embed' ),
242+ config ,
243+ ),
244+ rngs = rngs ,
231245 )
232246 self .dropout = nnx .Dropout (rate = config .dropout_rate )
233247
@@ -252,47 +266,51 @@ def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
252266 self .config = config
253267
254268 self .ln1 = nnx .LayerNorm (
255- num_features = config .emb_dim ,
256- dtype = config .dtype ,
257- bias_init = nnx .with_partitioning (
258- nnx .initializers .zeros_init (),
259- config .axis_rules ('embed' ),
260- ),
261- scale_init = nnx .with_partitioning (
262- nnx .initializers .ones_init (),
263- config .axis_rules ('embed' ),
264- ),
265- rngs = rngs ,
269+ num_features = config .emb_dim ,
270+ dtype = config .dtype ,
271+ bias_init = maybe_with_partitioning (
272+ nnx .initializers .zeros_init (),
273+ config .axis_rules ('embed' ),
274+ config ,
275+ ),
276+ scale_init = maybe_with_partitioning (
277+ nnx .initializers .ones_init (),
278+ config .axis_rules ('embed' ),
279+ config ,
280+ ),
281+ rngs = rngs ,
266282 )
267283 self .ln2 = nnx .LayerNorm (
268- num_features = config .emb_dim ,
269- dtype = config .dtype ,
270- bias_init = nnx .with_partitioning (
271- nnx .initializers .zeros_init (),
272- config .axis_rules ('embed' ),
273- ),
274- scale_init = nnx .with_partitioning (
275- nnx .initializers .ones_init (),
276- config .axis_rules ('embed' ),
277- ),
278- rngs = rngs ,
284+ num_features = config .emb_dim ,
285+ dtype = config .dtype ,
286+ bias_init = maybe_with_partitioning (
287+ nnx .initializers .zeros_init (),
288+ config .axis_rules ('embed' ),
289+ config ,
290+ ),
291+ scale_init = maybe_with_partitioning (
292+ nnx .initializers .ones_init (),
293+ config .axis_rules ('embed' ),
294+ config ,
295+ ),
296+ rngs = rngs ,
279297 )
280298 self .attention = nnx .MultiHeadAttention (
281- num_heads = config .num_heads ,
282- in_features = config .emb_dim ,
283- qkv_features = config .qkv_dim ,
284- dtype = config .dtype ,
285- kernel_init = nnx . with_partitioning (
286- config .kernel_init , config .axis_rules ('embed' , 'kv' )
287- ),
288- bias_init = nnx . with_partitioning (
289- config .bias_init , config .axis_rules ('embed' )
290- ),
291- use_bias = False ,
292- broadcast_dropout = False ,
293- dropout_rate = config .attention_dropout_rate ,
294- rngs = rngs ,
295- keep_rngs = False ,
299+ num_heads = config .num_heads ,
300+ in_features = config .emb_dim ,
301+ qkv_features = config .qkv_dim ,
302+ dtype = config .dtype ,
303+ kernel_init = maybe_with_partitioning (
304+ config .kernel_init , config .axis_rules ('embed' , 'kv' ), config
305+ ),
306+ bias_init = maybe_with_partitioning (
307+ config .bias_init , config .axis_rules ('embed' ), config
308+ ),
309+ use_bias = False ,
310+ broadcast_dropout = False ,
311+ dropout_rate = config .attention_dropout_rate ,
312+ rngs = rngs ,
313+ keep_rngs = False ,
296314 )
297315 self .mlp = MlpBlock (config = config , rngs = rngs )
298316 self .dropout = nnx .Dropout (rate = config .dropout_rate )
@@ -348,13 +366,14 @@ def __init__(
348366 # Target Embedding
349367 if self .shared_embedding is None :
350368 self .output_embed = nnx .Embed (
351- num_embeddings = config .output_vocab_size ,
352- features = config .emb_dim ,
353- embedding_init = nnx .with_partitioning (
354- nnx .initializers .normal (stddev = 1.0 ),
355- config .axis_rules ('vocab' , 'embed' ),
356- ),
357- rngs = rngs ,
369+ num_embeddings = config .output_vocab_size ,
370+ features = config .emb_dim ,
371+ embedding_init = maybe_with_partitioning (
372+ nnx .initializers .normal (stddev = 1.0 ),
373+ config .axis_rules ('vocab' , 'embed' ),
374+ config ,
375+ ),
376+ rngs = rngs ,
358377 )
359378 else :
360379 self .output_embed = self .shared_embedding
@@ -366,28 +385,28 @@ def __init__(
366385 setattr (self , f'encoderdecoderblock_{ idx } ' , layer )
367386
368387 self .encoderdecoder_norm = nnx .LayerNorm (
369- num_features = config .emb_dim ,
370- dtype = config .dtype ,
371- bias_init = nnx .with_partitioning (
372- nnx .initializers .zeros_init (), config .axis_rules ('embed' )
373- ),
374- scale_init = nnx .with_partitioning (
375- nnx .initializers .ones_init (), config .axis_rules ('embed' )
376- ),
377- rngs = rngs ,
378- )
379- if not config .logits_via_embedding :
380- self .logitdense = nnx .Linear (
381- in_features = config .emb_dim ,
382- out_features = config .output_vocab_size ,
388+ num_features = config .emb_dim ,
383389 dtype = config .dtype ,
384- kernel_init = nnx . with_partitioning (
385- config . kernel_init , config .axis_rules ('embed' , 'vocab' )
390+ bias_init = maybe_with_partitioning (
391+ nnx . initializers . zeros_init () , config .axis_rules ('embed' ), config
386392 ),
387- bias_init = nnx . with_partitioning (
388- config . bias_init , config .axis_rules ('vocab' )
393+ scale_init = maybe_with_partitioning (
394+ nnx . initializers . ones_init () , config .axis_rules ('embed' ), config
389395 ),
390396 rngs = rngs ,
397+ )
398+ if not config .logits_via_embedding :
399+ self .logitdense = nnx .Linear (
400+ in_features = config .emb_dim ,
401+ out_features = config .output_vocab_size ,
402+ dtype = config .dtype ,
403+ kernel_init = maybe_with_partitioning (
404+ config .kernel_init , config .axis_rules ('embed' , 'vocab' ), config
405+ ),
406+ bias_init = maybe_with_partitioning (
407+ config .bias_init , config .axis_rules ('vocab' ), config
408+ ),
409+ rngs = rngs ,
391410 )
392411 else :
393412 self .logitdense = None
0 commit comments