
Commit 62ee4c2

danielsuo authored and Flax Authors committed
[flax:examples:lm1b_nnx] Update example to work internally. #jax-fixit.
PiperOrigin-RevId: 838909002
1 parent ec85cdd · commit 62ee4c2

6 files changed, +145 -154 lines changed

examples/lm1b_nnx/configs/default.py

Lines changed: 4 additions & 4 deletions
@@ -107,10 +107,10 @@ class Config:
   # Parallelism
   mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor')
   axis_rules: MeshRules = MeshRules(
-    embed='fsdp',
-    mlp='tensor',
-    kv='tensor',
-    vocab='tensor',
+    embed=None,
+    mlp=None,
+    kv=None,
+    vocab=None,
   )
   data_sharding: tuple[str, ...] = ('data',)
   # One axis for each parallelism type may hold a placeholder (-1)
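Net effect of this config change: config.axis_rules(...) now resolves every logical axis to None, so no sharding names reach the model. A minimal sketch of that lookup, assuming MeshRules is the small dataclass mapping this example uses (only the fields visible in the diff are modeled):

import dataclasses


@dataclasses.dataclass(unsafe_hash=True)
class MeshRules:
  # Assumed shape of the example's MeshRules helper.
  embed: str | None = None
  mlp: str | None = None
  kv: str | None = None
  vocab: str | None = None

  def __call__(self, *keys: str) -> tuple[str | None, ...]:
    # Resolve logical axis names ('embed', 'mlp', ...) to mesh axis names.
    return tuple(getattr(self, key) for key in keys)


old_rules = MeshRules(embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor')
new_rules = MeshRules()  # the updated defaults: everything None

assert old_rules('embed', 'mlp') == ('fsdp', 'tensor')
assert new_rules('embed', 'mlp') == (None, None)  # no sharding annotations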

examples/lm1b_nnx/models.py

Lines changed: 105 additions & 86 deletions
@@ -66,6 +66,16 @@ def replace(self, **kwargs):
     return dataclasses.replace(self, **kwargs)


+def maybe_with_partitioning(
+    init_fn: nnx.Initializer,
+    names: tuple[str | None, ...],
+    config: TransformerConfig,
+) -> nnx.Initializer:
+  if all(name is None for name in names):
+    return init_fn
+  return nnx.with_partitioning(init_fn, names)
+
+
 def shift_right(x: jax.Array, axis: int = 1):
   """Shift the input to the right by padding and slicing on axis."""
   pad_widths: list[tuple[int, int]] = [(0, 0)] * len(x.shape)
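The helper above is the core of the change: it wraps an initializer with nnx.with_partitioning only when at least one axis name is set, and otherwise returns the initializer untouched, so parameters carry no sharding metadata under the all-None defaults. A standalone sketch of that behavior (the config argument is unused by the helper body shown in this diff and is passed as None here purely for illustration):

from flax import nnx


def maybe_with_partitioning(init_fn, names, config):
  # Same logic as the helper added above, minus the type annotations.
  if all(name is None for name in names):
    return init_fn
  return nnx.with_partitioning(init_fn, names)


kernel_init = nnx.initializers.xavier_uniform()

# All-None axis rules (the new default): the raw initializer comes back.
assert maybe_with_partitioning(kernel_init, (None, None), None) is kernel_init

# Explicit rules: same as the old code path, sharding names get recorded.
sharded_init = maybe_with_partitioning(kernel_init, ('fsdp', 'tensor'), None)
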
@@ -202,32 +212,36 @@ def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
     self.config = config

     self.linear1 = nnx.Linear(
-      config.emb_dim,
-      config.mlp_dim,
-      dtype=config.dtype,
-      kernel_init=nnx.with_partitioning(
-        config.kernel_init,
-        config.axis_rules('embed', 'mlp'),
-      ),
-      bias_init=nnx.with_partitioning(
-        config.bias_init,
-        config.axis_rules('mlp'),
-      ),
-      rngs=rngs,
+        config.emb_dim,
+        config.mlp_dim,
+        dtype=config.dtype,
+        kernel_init=maybe_with_partitioning(
+            config.kernel_init,
+            config.axis_rules('embed', 'mlp'),
+            config,
+        ),
+        bias_init=maybe_with_partitioning(
+            config.bias_init,
+            config.axis_rules('mlp'),
+            config,
+        ),
+        rngs=rngs,
     )
     self.linear2 = nnx.Linear(
-      config.mlp_dim,
-      config.emb_dim,
-      dtype=config.dtype,
-      kernel_init=nnx.with_partitioning(
-        config.kernel_init,
-        config.axis_rules('mlp', 'embed'),
-      ),
-      bias_init=nnx.with_partitioning(
-        config.bias_init,
-        config.axis_rules('embed'),
-      ),
-      rngs=rngs,
+        config.mlp_dim,
+        config.emb_dim,
+        dtype=config.dtype,
+        kernel_init=maybe_with_partitioning(
+            config.kernel_init,
+            config.axis_rules('mlp', 'embed'),
+            config,
+        ),
+        bias_init=maybe_with_partitioning(
+            config.bias_init,
+            config.axis_rules('embed'),
+            config,
+        ),
+        rngs=rngs,
     )
     self.dropout = nnx.Dropout(rate=config.dropout_rate)

@@ -252,47 +266,51 @@ def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
     self.config = config

     self.ln1 = nnx.LayerNorm(
-      num_features=config.emb_dim,
-      dtype=config.dtype,
-      bias_init=nnx.with_partitioning(
-        nnx.initializers.zeros_init(),
-        config.axis_rules('embed'),
-      ),
-      scale_init=nnx.with_partitioning(
-        nnx.initializers.ones_init(),
-        config.axis_rules('embed'),
-      ),
-      rngs=rngs,
+        num_features=config.emb_dim,
+        dtype=config.dtype,
+        bias_init=maybe_with_partitioning(
+            nnx.initializers.zeros_init(),
+            config.axis_rules('embed'),
+            config,
+        ),
+        scale_init=maybe_with_partitioning(
+            nnx.initializers.ones_init(),
+            config.axis_rules('embed'),
+            config,
+        ),
+        rngs=rngs,
     )
     self.ln2 = nnx.LayerNorm(
-      num_features=config.emb_dim,
-      dtype=config.dtype,
-      bias_init=nnx.with_partitioning(
-        nnx.initializers.zeros_init(),
-        config.axis_rules('embed'),
-      ),
-      scale_init=nnx.with_partitioning(
-        nnx.initializers.ones_init(),
-        config.axis_rules('embed'),
-      ),
-      rngs=rngs,
+        num_features=config.emb_dim,
+        dtype=config.dtype,
+        bias_init=maybe_with_partitioning(
+            nnx.initializers.zeros_init(),
+            config.axis_rules('embed'),
+            config,
+        ),
+        scale_init=maybe_with_partitioning(
+            nnx.initializers.ones_init(),
+            config.axis_rules('embed'),
+            config,
+        ),
+        rngs=rngs,
     )
     self.attention = nnx.MultiHeadAttention(
-      num_heads=config.num_heads,
-      in_features=config.emb_dim,
-      qkv_features=config.qkv_dim,
-      dtype=config.dtype,
-      kernel_init=nnx.with_partitioning(
-        config.kernel_init, config.axis_rules('embed', 'kv')
-      ),
-      bias_init=nnx.with_partitioning(
-        config.bias_init, config.axis_rules('embed')
-      ),
-      use_bias=False,
-      broadcast_dropout=False,
-      dropout_rate=config.attention_dropout_rate,
-      rngs=rngs,
-      keep_rngs=False,
+        num_heads=config.num_heads,
+        in_features=config.emb_dim,
+        qkv_features=config.qkv_dim,
+        dtype=config.dtype,
+        kernel_init=maybe_with_partitioning(
+            config.kernel_init, config.axis_rules('embed', 'kv'), config
+        ),
+        bias_init=maybe_with_partitioning(
+            config.bias_init, config.axis_rules('embed'), config
+        ),
+        use_bias=False,
+        broadcast_dropout=False,
+        dropout_rate=config.attention_dropout_rate,
+        rngs=rngs,
+        keep_rngs=False,
     )
     self.mlp = MlpBlock(config=config, rngs=rngs)
     self.dropout = nnx.Dropout(rate=config.dropout_rate)
@@ -348,13 +366,14 @@ def __init__(
     # Target Embedding
     if self.shared_embedding is None:
       self.output_embed = nnx.Embed(
-        num_embeddings=config.output_vocab_size,
-        features=config.emb_dim,
-        embedding_init=nnx.with_partitioning(
-          nnx.initializers.normal(stddev=1.0),
-          config.axis_rules('vocab', 'embed'),
-        ),
-        rngs=rngs,
+          num_embeddings=config.output_vocab_size,
+          features=config.emb_dim,
+          embedding_init=maybe_with_partitioning(
+              nnx.initializers.normal(stddev=1.0),
+              config.axis_rules('vocab', 'embed'),
+              config,
+          ),
+          rngs=rngs,
       )
     else:
       self.output_embed = self.shared_embedding
@@ -366,28 +385,28 @@ def __init__(
       setattr(self, f'encoderdecoderblock_{idx}', layer)

     self.encoderdecoder_norm = nnx.LayerNorm(
-      num_features=config.emb_dim,
-      dtype=config.dtype,
-      bias_init=nnx.with_partitioning(
-        nnx.initializers.zeros_init(), config.axis_rules('embed')
-      ),
-      scale_init=nnx.with_partitioning(
-        nnx.initializers.ones_init(), config.axis_rules('embed')
-      ),
-      rngs=rngs,
-    )
-    if not config.logits_via_embedding:
-      self.logitdense = nnx.Linear(
-        in_features=config.emb_dim,
-        out_features=config.output_vocab_size,
+        num_features=config.emb_dim,
         dtype=config.dtype,
-        kernel_init=nnx.with_partitioning(
-          config.kernel_init, config.axis_rules('embed', 'vocab')
+        bias_init=maybe_with_partitioning(
+            nnx.initializers.zeros_init(), config.axis_rules('embed'), config
         ),
-        bias_init=nnx.with_partitioning(
-          config.bias_init, config.axis_rules('vocab')
+        scale_init=maybe_with_partitioning(
+            nnx.initializers.ones_init(), config.axis_rules('embed'), config
         ),
         rngs=rngs,
+    )
+    if not config.logits_via_embedding:
+      self.logitdense = nnx.Linear(
+          in_features=config.emb_dim,
+          out_features=config.output_vocab_size,
+          dtype=config.dtype,
+          kernel_init=maybe_with_partitioning(
+              config.kernel_init, config.axis_rules('embed', 'vocab'), config
+          ),
+          bias_init=maybe_with_partitioning(
+              config.bias_init, config.axis_rules('vocab'), config
+          ),
+          rngs=rngs,
       )
     else:
       self.logitdense = None

examples/lm1b_nnx/models_test.py

Lines changed: 31 additions & 59 deletions
@@ -33,13 +33,8 @@

 jax.config.update('jax_disable_most_optimizations', True)

-# add project_root to import lm1b Linen model
-# "/path/to/flax/examples/lm1b_nnx/models_test.py" -> "/path/to/flax"
-project_root = str(Path(__file__).absolute().parents[2])
-sys.path.append(project_root)
-from examples.lm1b.models import TransformerLM as TransformerLinen  # type: ignore[import-error]
-
-sys.path.pop()
+# Import lm1b Linen model for compatibility testing
+from flax.examples.lm1b.models import TransformerLM as TransformerLinen


 @dataclasses.dataclass(unsafe_hash=True)
@@ -51,23 +46,23 @@ class CompatTransformerConfig(TransformerConfig):
 def get_transformer_config(**kwargs):
   base_config = default.get_config()
   config = CompatTransformerConfig(
-    vocab_size=base_config.vocab_size,
-    output_vocab_size=base_config.vocab_size,
-    logits_via_embedding=base_config.logits_via_embedding,
-    dtype=jnp.bfloat16 if base_config.use_bfloat16 else jnp.float32,
-    emb_dim=base_config.emb_dim,
-    num_heads=base_config.num_heads,
-    num_layers=base_config.num_layers,
-    qkv_dim=base_config.qkv_dim,
-    mlp_dim=base_config.mlp_dim,
-    max_len=max(
-      base_config.max_target_length, base_config.max_eval_target_length
-    ),
-    dropout_rate=base_config.dropout_rate,
-    attention_dropout_rate=base_config.attention_dropout_rate,
-    kernel_init=nnx.initializers.xavier_uniform(),
-    bias_init=nnx.initializers.normal(stddev=1e-6),
-    **kwargs,
+      vocab_size=base_config.vocab_size,
+      output_vocab_size=base_config.vocab_size,
+      logits_via_embedding=base_config.logits_via_embedding,
+      dtype=jnp.bfloat16 if base_config.use_bfloat16 else jnp.float32,
+      emb_dim=base_config.emb_dim,
+      num_heads=base_config.num_heads,
+      num_layers=base_config.num_layers,
+      qkv_dim=base_config.qkv_dim,
+      mlp_dim=base_config.mlp_dim,
+      max_len=max(
+          base_config.max_target_length, base_config.max_eval_target_length
+      ),
+      dropout_rate=base_config.dropout_rate,
+      attention_dropout_rate=base_config.attention_dropout_rate,
+      kernel_init=nnx.initializers.xavier_uniform(),
+      bias_init=nnx.initializers.normal(stddev=1e-6),
+      **kwargs,
   )
   return base_config, config


@@ -93,9 +88,10 @@ def copy_var(nnx_name: str, linen_name: str):
         == flat_params_linen[linen_name].value.shape
       )
       flat_params_nnx[nnx_path].value = flat_params_linen[linen_name].value
-      assert flat_params_nnx[nnx_path].sharding == apply_rules(
-        flat_params_linen[linen_name].names
-      )
+      if not all(rule is None for rule in rules.values()):
+        assert flat_params_nnx[nnx_path].sharding == apply_rules(
+            flat_params_linen[linen_name].names
+        )

     copy_var('decoder/output_embed/embedding', 'decoder/Embed_0/embedding')
     copy_var(
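This guard mirrors the model change: with every axis rule set to None, maybe_with_partitioning never attaches sharding metadata, so there is nothing for the test to compare. A rough sketch of the intent, with rules and apply_rules as hypothetical stand-ins for the test's own helpers (they are defined outside this hunk):

rules = {'embed': None, 'mlp': None, 'kv': None, 'vocab': None}  # new defaults


def apply_rules(names):
  # Map logical axis names to mesh axes via the configured rules.
  return tuple(rules.get(name) for name in names)


if not all(rule is None for rule in rules.values()):
  # Only runs when partitioning is actually configured; skipped under the
  # new defaults, where NNX params carry no sharding names to compare.
  print(apply_rules(('vocab', 'embed')))
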
@@ -196,14 +192,8 @@ def copy_var(nnx_name: str, linen_name: str):

   def test_forward_eval(self):
     _, config = get_transformer_config(
-      axis_rules=default.MeshRules(
-        embed='model',
-        mlp='data',
-        kv=None,
-        vocab=None,
-      ),
-      deterministic=True,
-      decode=False,
+        deterministic=True,
+        decode=False,
     )
     # Set dropout rates to avoid create dropout states
     config.dropout_rate = 0.0
@@ -233,14 +223,8 @@ def test_forward_decode(self):
     batch_size = 2

     _, config = get_transformer_config(
-      axis_rules=default.MeshRules(
-        embed='model',
-        mlp='data',
-        kv=None,
-        vocab=None,
-      ),
-      deterministic=True,
-      decode=True,
+        deterministic=True,
+        decode=True,
     )
     # Set dropout rates to avoid create dropout states
     config.dropout_rate = 0.0
@@ -293,14 +277,8 @@ def test_forward_decode(self):

   def test_forward_eval_set_mode(self):
     _, config = get_transformer_config(
-      axis_rules=default.MeshRules(
-        embed='model',
-        mlp='data',
-        kv=None,
-        vocab=None,
-      ),
-      deterministic=True,
-      decode=False,
+        deterministic=True,
+        decode=False,
     )
     # Set dropout rates to avoid create dropout states
     config.dropout_rate = 0.0
@@ -330,14 +308,8 @@ def test_forward_decode_set_mode(self):
     batch_size = 2

     _, config = get_transformer_config(
-      axis_rules=default.MeshRules(
-        embed='model',
-        mlp='data',
-        kv=None,
-        vocab=None,
-      ),
-      deterministic=True,
-      decode=True,
+        deterministic=True,
+        decode=True,
     )
     # Set dropout rates to avoid create dropout states
     config.dropout_rate = 0.0
