Add meshes and config for TRN2/1 for Fuji models #885
apoorvtintin wants to merge 1 commit into apple:main
Conversation
    mesh_rules=(
        (
            "neuron-(trn2|trn2n).48xlarge-64",
            mesh_shape_from_axes(fsdp=-1, model=4),
Could you add a comment on why we set model=4 for Neuron?
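For context, a minimal sketch (hypothetical helper, not axlearn code) of how a mesh rule like `mesh_shape_from_axes(fsdp=-1, model=4)` resolves on a 64-device trn2.48xlarge node group: `-1` means "use all remaining devices", so the fsdp axis becomes 64 / 4 = 16.

```python
# Illustrative resolver for -1 wildcard mesh axes; resolve_mesh_shape is a
# hypothetical name, assuming 64 Neuron devices per the "...-64" rule above.
def resolve_mesh_shape(num_devices, **axes):
    """Resolve -1 wildcard axes given a total device count (illustrative only)."""
    fixed = 1
    for size in axes.values():
        if size != -1:
            fixed *= size
    assert num_devices % fixed == 0, "device count must be divisible by fixed axes"
    return {
        name: (num_devices // fixed if size == -1 else size)
        for name, size in axes.items()
    }

print(resolve_mesh_shape(64, fsdp=-1, model=4))  # {'fsdp': 16, 'model': 4}
```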
    if num_kv_heads:
        atten_cfg = GroupedQueryAttention.default_config()
        atten_input_linear = FusedGroupedQKVLinear.default_config().set(num_kv_heads=num_kv_heads)
    backend = jax.default_backend()
The fuji config should not depend on jax.default_backend(), otherwise the golden configs will not reflect the actual config being used.
Instead, we can create separate configs for a backend that requires different settings.
+1, please follow this example instead. If you really need to overwrite some configs, you can add another custom LayerConfigModifier like this one: https://github.com/apple/axlearn/blob/main/axlearn/common/trainer_config_modifier.py#L69
Thanks for the review, will update the PR with a custom LayerConfigModifier.
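For readers following along, a simplified sketch of the config-modifier pattern being suggested (all names here are hypothetical stand-ins; see axlearn's trainer_config_modifier.py for the real API). The point is that the base fuji config stays backend-agnostic, and Neuron-specific overrides live in a separately applied modifier rather than a `jax.default_backend()` branch.

```python
# Hypothetical, simplified config-modifier sketch; LayerCfg, TrainerCfg, and
# neuron_layer_modifier are illustrative names, not axlearn classes.
from dataclasses import dataclass, field

@dataclass
class LayerCfg:
    stack_type: str = "repeated"                  # backend-agnostic default
    qkv_linear: str = "FusedGroupedQKVLinear"

@dataclass
class TrainerCfg:
    layer: LayerCfg = field(default_factory=LayerCfg)

def neuron_layer_modifier(cfg: TrainerCfg) -> TrainerCfg:
    """Overrides applied only when building the Neuron variant of the config."""
    cfg.layer.stack_type = "stacked"              # StackedTransformerLayer
    cfg.layer.qkv_linear = "GroupedQKVLinear"     # unfused GQA projections
    return cfg

base = TrainerCfg()
neuron = neuron_layer_modifier(TrainerCfg())
print(base.layer.stack_type, neuron.layer.stack_type)  # repeated stacked
```

This keeps the golden config for the default backend unchanged, since the modifier is applied explicitly per named config rather than via runtime backend detection.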
        raise NotImplementedError(f"Unknown model size {model_size}.")
    model_kwargs = trainer_kwargs.pop("model_kwargs")
    model_kwargs.setdefault("vocab_size", vocab_size)
    model_kwargs.setdefault("stack_cfg", None if backend != "neuron" else StackedTransformerLayer.default_config())
Will the use of StackedTransformerLayer (vs. RepeatedTransformerLayer) lead to large XLA programs and long compilation time?
We are in the middle of optimizing RepeatedTransformer to use a new hardware feature in TRN2 to make dynamic memory operations faster. In the meantime, please continue to use StackedTransformer. Neuron compiler has a module to detect repeating blocks, compile once and reuse. So, compile time won't grow with the number of layers.
> We are in the middle of optimizing RepeatedTransformer to use a new hardware feature in TRN2 to make dynamic memory operations faster. In the meantime, please continue to use StackedTransformer. Neuron compiler has a module to detect repeating blocks, compile once and reuse. So, compile time won't grow with the number of layers.
Nice! Could you add this as a comment?
Opened a new PR from my fork of Axlearn (#916). All comments in this discussion have been addressed there.
Closing since there is a new version of this PR (#916).
This PR adds meshes for TRN2/TRN1 for Fuji models, along with a transformer-layer configuration favorable to Neuron: Neuron supports StackedTransformerLayer, and GroupedQKVLinear instead of FusedGroupedQKVLinear, for Grouped Query Attention (GQA).
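To illustrate why GQA uses a grouped (or unfused) QKV projection, here is a minimal NumPy shape sketch (illustrative only, not the axlearn layers): with 8 query heads but only 2 KV heads, Q and K/V need different projection shapes, and each group of 4 query heads shares one KV head.

```python
import numpy as np

# Toy GQA shapes; all sizes here are made up for illustration.
num_heads, num_kv_heads, head_dim, model_dim, seq = 8, 2, 16, 128, 10
x = np.ones((seq, model_dim))

# Separate (grouped) projections: Q has num_heads heads, K/V have num_kv_heads.
w_q = np.zeros((model_dim, num_heads * head_dim))
w_k = np.zeros((model_dim, num_kv_heads * head_dim))
w_v = np.zeros((model_dim, num_kv_heads * head_dim))

q = (x @ w_q).reshape(seq, num_heads, head_dim)
k = (x @ w_k).reshape(seq, num_kv_heads, head_dim)
# Broadcast K across query-head groups: each KV head serves 4 query heads.
k_expanded = np.repeat(k, num_heads // num_kv_heads, axis=1)
print(q.shape, k_expanded.shape)  # (10, 8, 16) (10, 8, 16)
```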