-
Notifications
You must be signed in to change notification settings - Fork 304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate Orbax's emergency checkpoint. #820
Conversation
b4a00eb
to
0e74dc8
Compare
cf43485
to
56f51de
Compare
d5a3e0f
to
bea8b71
Compare
65f3d46
to
c1a476d
Compare
del os.environ["JAX_PLATFORMS"] | ||
|
||
|
||
class OrbaxEmergencyCheckpointer(BaseCheckpointer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this overlap with
axlearn/axlearn/common/checkpointer_orbax.py
Line 169 in 140a18f
class OrbaxCheckpointer(BaseCheckpointer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The orbax regular checkpointer saves 1 gcs checkpoint for n slices per save. This checkpointer saves n-1 checkpoints to a local path (usually a ramdisk), and also 1 checkpoint to gcs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately it's not possible to share code between the two implementations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification.
- How should users use them? Should they use both of them or one of them? How should they pick?
- Or can we replace OrbaxCheckpointer with this class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a comment to clarify this:
This checkpointer is designed to improve the goodput of large multi-slice training jobs that
use data-parallelism across slices. At least two data-parallel slices are required. For other
use cases where this is not applicable or ultimate goodput is not required, please use
`OrbaxCheckpointer`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be an idea.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still needs support from orbax though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Right now it assumes local checkpoints are always used: https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L695-L709.
Could you raise this request to the Orbax team and link to the issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation of the constraints. I wonder what the long term plan is.
Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?
Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?
I don't know if Google has such a plan. The orbax in-memory checkpointer actually uses the orbax regular checkpointer under the hood, which might be required by design/by nature of the problem that it solves.
Since in-memory checkpointer uses the regular orbax checkpointer under the hood, the tensor state in the persistent checkpoint (i.e. the one stored to gcs) can be loaded by |
I think in the long term, it's probably possible to unify the checkpoint structure between the two checkpointer (regular and in-memory), but it's unknown whether we can unify the codepath. |
del os.environ["JAX_PLATFORMS"] | ||
|
||
|
||
class OrbaxEmergencyCheckpointer(BaseCheckpointer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.
# Note that save() waits for prior serialization to finish. | ||
self._non_tensor_manager.save(step=step, state=state) | ||
self._get_tensor_manager(state_with_tensors).save( | ||
step=step, args=ocp.args.PyTreeSave(item=state_with_tensors) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we mark the completion of a checkpoint? It should happen only when both tensor and non-tensor states are saved. How is this ensured?
Please add a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no special marker for completion of both. There are only markers for each of them individually. So, during restore, we look for both of them only load a specific step when both marker exists.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Can you add a comment here and point to where "we look for both of them only load a specific step when both marker exists"?
Do we have testing for incomplete checkpoints?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add a testcase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@ruomingp I guess my question now is what's the plan here. Should we wait for Orbax's support for non tensor states and unified checkpointer API? I personally don't see in-mem ckpt as a life-changing feature, so waiting could be an viable option. Alternatively, we can proceed with this PR and make changes later. |
There's no best solution. I see three possibilities: I do not think we want to maintain two Orbax checkpointers in the longer run, especially with incompatible layouts. WDYT? |
# Note that save() waits for prior serialization to finish. | ||
self._non_tensor_manager.save(step=step, state=state) | ||
self._get_tensor_manager(state_with_tensors).save( | ||
step=step, args=ocp.args.PyTreeSave(item=state_with_tensors) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Can you add a comment here and point to where "we look for both of them only load a specific step when both marker exists"?
Do we have testing for incomplete checkpoints?
global_mesh=thread_resources.env.physical_mesh, | ||
abstract_state=self._get_abstract_state(state_with_tensors), | ||
options=oecp.CheckpointManagerOptions( | ||
local=oecp.LocalCheckpointOptions( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we expose oecp.LocalCheckpointOptions
to users as Config.local_checkpoint_options
? User can set it to None to disable local checkpoints.
We can provide a helper function for users to construct should_save_fn
from their policy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's already exposed, via local_keep_last_n
and local_save_policy
# Find the intersection of the checkpoint steps managed by tensor and non-tensor | ||
# manager, and then use the latest step in the intersection for restore. `all_steps` | ||
# from tensor manager contains both local and persistent checkpoints. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider refactoring this logic to a separate function so that it can be tested directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done and added a test.
del os.environ["JAX_PLATFORMS"] | ||
|
||
|
||
class OrbaxEmergencyCheckpointer(BaseCheckpointer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Right now it assumes local checkpoints are always used: https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L695-L709.
Could you raise this request to the Orbax team and link to the issue?
I think maybe waiting may not be the best idea, given that Orbax dev is often delayed, it has been for over a year since last time Anthropic raised this in-mem checkpointing feature request. I think both options below works for me
Maybe let's merge it for testing (we need quite sometime to test throughout), and then wait for a quarter and see if Orbax can really make it happen and available to our users? if so we can discard the tf iterator handling and use the upstream unified one. If not, then we can commit to unification ourselves? |
@kelvin-zou SGTM. |
b702c54
to
059dfa2
Compare
@ruomingp Could you please approve this PR if it looks good? |
059dfa2
to
939409b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving to unblock experiments.
"""Checkpointer implementation that uses Orbax emergency checkpoint. | ||
|
||
## Summary: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Checkpointer implementation that uses Orbax emergency checkpoint. | |
## Summary: | |
"""Checkpointer implementation that uses Orbax emergency checkpoint. | |
EXPERIMENTAL. Do not use for actual training runs since the checkpoint layout will likely change in the future. | |
## Summary: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
persistent_directory=os.path.join(cfg.dir, self._TENSORS_PREFIX), | ||
global_mesh=thread_resources.env.physical_mesh, | ||
abstract_state=self._get_abstract_state(state_with_tensors), | ||
options=oecp.CheckpointManagerOptions( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
local_state_handler
is missing here.
CheckpointManager::init does not have default value for local_state_handler
def __init__(
self,
local_directory: epath.PathLike,
persistent_directory: epath.PathLike,
global_mesh: jax.sharding.Mesh,
abstract_state: PyTree, # a single PyTree describing the state structure
# TODO: b/330585086 - Support arbitrary items beyond state. We will have
# to evaluate whether arbitrary items can be a good fit for local
# checkpointing, given restore+broadcast requirements.
local_state_handler: CheckpointHandler,
*,
options: Optional[CheckpointManagerOptions] = None,
metadata: Optional[dict[str, Any]] = None,
logger: Optional[abstract_logger.AbstractLogger] = None,
):
pre-commit and pytype fail too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, upgrading to orbax-checkpoint-0.11.1 solved the issue
commit 336c75d Author: Mark Lee <[email protected]> Date: Mon Mar 3 09:04:07 2025 -0800 Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029) * Allows specifying PartitionSpec to host_to_global_device_array. * Generalizes to arbitrary uniform partitioning. * Addresses comments and adds mixed shape test. commit 0881412 Author: Dongseong Hwang <[email protected]> Date: Sat Mar 1 15:41:38 2025 -0800 Refactor Mask in Attention (apple#1028) Currently, the attention code is **hardcoded** to handle either `causal_mask` or an arbitrary `mask_fn`. To support **sliding window masks**, we previously used a **hack** by injecting the `_sliding_window_size` attribute into functions. This refactor **makes the masking logic more flexible** by allowing arbitrary `MaskFnAttentionBias`. - If downstream requires a **new mask pattern**, they can simply: 1. Implement a **subclass of `MaskFnAttentionBias`**. 2. Set `attention.mask` accordingly. commit f67d3f9 Author: Dongseong Hwang <[email protected]> Date: Fri Feb 28 08:53:00 2025 -0800 Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026) Currently, Flash Attention infers decoding implicitly based on circumstantial evidence. This PR makes the check explicit. commit f8d2c66 Author: qdavid1 <[email protected]> Date: Thu Feb 27 15:26:18 2025 -0800 External KV input for _update_layer_kwargs (apple#1025) commit a3bf5e2 Author: Hanzhi Zhou <[email protected]> Date: Wed Feb 26 17:23:40 2025 -0800 Minor changes to Checkpointer (apple#1024) commit 55e1841 Author: Wentao Wu <[email protected]> Date: Wed Feb 26 15:45:51 2025 -0800 Add an option to break ties for top_k_logits when k = 1 (apple#1022) * Add an option to support stable top_k = 1. * address comments * address comments * address comments * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * address comments --------- Co-authored-by: Mark Lee <[email protected]> commit fbca3fc Author: Meng (Ethan) Li <[email protected]> Date: Wed Feb 26 14:05:25 2025 -0800 Add priority_class as a launch flag (apple#1020) commit b26bd74 Author: Meng (Ethan) Li <[email protected]> Date: Wed Feb 26 14:04:47 2025 -0800 Fix TypeError in calcualte_goodput.py (apple#1023) commit f8191e1 Author: Dongseong Hwang <[email protected]> Date: Wed Feb 26 11:03:44 2025 -0800 Emulate flash attentnion unittests on CPU. (apple#1021) utils.py codebase is not well covered by CI because it branches different backend. This PR introduces new CPU test, utils_test.py. This test is expected to run on CPU and is designed to validate GPU/TPU code from a CPU environment by fake mesh. It allows quick verification in CI and local environments to ensure that code changes do not break GPU/TPU Flash Attention. commit daec8c5 Author: Chang Liu <[email protected]> Date: Tue Feb 25 12:38:43 2025 -0800 Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019) Co-authored-by: Chang Liu <[email protected]> commit ac642ea Author: Dongseong Hwang <[email protected]> Date: Tue Feb 25 12:02:57 2025 -0800 Fix crash in log-mel frontend when waveform samples are integers. (apple#1017) After updating JAX, this existing hidden bug started causing CI failures. When the sample dtype is int32 (which is valid), `jnp.finfo` returns None, even though `jnp.iinfo` is available. The previous JAX version seemed to handle this case more forgivingly. ``` ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny)) ``` commit 7c64b55 Author: Meng (Ethan) Li <[email protected]> Date: Tue Feb 25 10:52:50 2025 -0800 Add LoadBalancer to GKE replicatedJob (apple#1015) Co-authored-by: Liang (SPG) He <[email protected]> commit 8e8a41b Author: Chang Lan <[email protected]> Date: Tue Feb 25 10:26:59 2025 -0800 Expose jax.lax.scan's unroll option to Repeat layer (apple#1016) * Expose jax.lax.scan's unroll option to Repeat layer. * Defaults to None to avoid golden config changes commit 682bce6 Author: Dongseong Hwang <[email protected]> Date: Tue Feb 25 10:09:41 2025 -0800 Handle None bias in BiasAndResidual (apple#1018) commit f053318 Author: Ruoming Pang <[email protected]> Date: Mon Feb 24 11:12:23 2025 -0500 Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013) * Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate(). * Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate(). commit a93cd1b Author: Luzy <[email protected]> Date: Sat Feb 22 19:55:01 2025 -0500 fix dtype in frontend pre emphasis (apple#1014) commit c1fe2e9 Author: Maggie Zhang <[email protected]> Date: Fri Feb 21 19:07:39 2025 -0800 GoodPut minor fix: only process 0 should start goodput uploader (apple#984) * only process 0 will start goodput uploader * Add unit test commit 4b1fbf0 Author: Chang Lan <[email protected]> Date: Fri Feb 21 13:36:45 2025 -0800 Async context invocation for checkpointing (apple#1012) * Async context invocation supoprt for checkpointer * Add comment * Add comments commit d4cd158 Author: Ruoming Pang <[email protected]> Date: Fri Feb 21 14:22:24 2025 -0500 Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011) * Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor. * Fixes pytype. * Addresses review. commit 2ae6e66 Author: Meng (Ethan) Li <[email protected]> Date: Fri Feb 21 09:22:35 2025 -0800 Enable megascale abort on hang or error (apple#1010) * Enable megascale_error_reporter_abort on hang and error by default * Increase threshold to 10m commit ce4b2fb Author: Chunyang Wen <[email protected]> Date: Fri Feb 21 23:12:38 2025 +0800 Add GPU monitor (apple#1006) commit baf8ad7 Author: Dongseong Hwang <[email protected]> Date: Thu Feb 20 19:35:07 2025 -0800 Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009) commit cf41112 Author: Hanzhi Zhou <[email protected]> Date: Thu Feb 20 16:29:13 2025 -0800 Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008) * Revert "gRPC Checkpointer (apple#1005)" This reverts commit d27c562. * Keep some changes commit 454bdba Author: Matthew Hopkins <[email protected]> Date: Thu Feb 20 15:10:51 2025 -0800 upgrade jax 0.4.38 (apple#1007) commit d27c562 Author: Hanzhi Zhou <[email protected]> Date: Tue Feb 18 18:54:53 2025 -0800 gRPC Checkpointer (apple#1005) commit fb90620 Author: Ruoming Pang <[email protected]> Date: Tue Feb 18 21:08:38 2025 -0500 Makes file_system.glob support multiple patterns. (apple#1003) * Makes file_system.glob support multiple patterns. * Makes file_system.glob support multiple patterns. * Makes file_system.glob support multiple patterns. * Makes file_system.glob support multiple patterns. commit 334f421 Author: Mark Lee <[email protected]> Date: Tue Feb 18 17:03:39 2025 -0800 Reverts sliding window attention changes. (apple#1004) * Revert "Fix flash decoding in GPU. (apple#999)" This reverts commit fdadfd8. * Revert "Supports TPU context parallel training (apple#981)" This reverts commit e151d69. * Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)" This reverts commit 67645d0. * Retain model/decoder asr changes. commit 3dacc6b Author: Chang Lan <[email protected]> Date: Mon Feb 17 19:45:00 2025 -0800 Refactor aot_compilation for reuse (apple#1000) commit c44fe18 Author: Ruoming Pang <[email protected]> Date: Mon Feb 17 21:38:17 2025 -0500 Makes checkpointer_test.py use file_system. (apple#1001) commit fdadfd8 Author: Dongseong Hwang <[email protected]> Date: Mon Feb 17 18:23:21 2025 -0800 Fix flash decoding in GPU. (apple#999) target_positions used to be time_step, but after PR apple#995, it now represents the actual target positions with shape [batch, step_len]. apple#995 Updating the GPU decoding code to align with this change. CI did not cover GPU unit tests. TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU commit 9e64388 Author: Ruoming Pang <[email protected]> Date: Mon Feb 17 16:40:03 2025 -0500 Makes axlearn/cloud/ use file_system. (apple#998) * Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency. * Adds testing for file_system.readfile. * Fixes pytype. * Makes axlearn/cloud use file_system instead of gfile. commit 5fba4ce Author: Chang Lan <[email protected]> Date: Mon Feb 17 09:44:10 2025 -0800 AOT compilation support for inference (apple#997) * Add optional `devices` init argument to InferenceRunner for passing fake devices during AOT compilation. * Add more v5e slice types. commit e151d69 Author: Hanzhi Zhou <[email protected]> Date: Sun Feb 16 13:01:15 2025 -0800 Supports TPU context parallel training (apple#981) Fix Fix tests commit 67645d0 Author: Dongseong Hwang <[email protected]> Date: Sat Feb 15 13:26:51 2025 -0800 Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995) * Revert "Transpose kv cache for better decode performance (apple#979)" This reverts commit b130416. * Update golden configs * Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for sliding window attention, the KV cache is kept for the full sequence length (`seq_len`) instead of the window length (`window_len`). For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache for the full 2M tokens. It then biases 1999k invalid KV tokens before calculating attention, resulting in a computational complexity of **O(2M²)** instead of the desired **O(1k²)**. This issue persists even when using flash attention. Flash attention uses the KV cache allocated in HBM as its input. While unnecessary blocks are discarded during computation, the KV cache still occupies HBM inefficiently for the full 2M tokens. To address this, when `MultiheadAttention` detects a sliding window mask, it stores the key-value (KV) cache in a ring buffer inside the input linear layer. As a result, downstream projects using `MultiheadAttention` automatically benefit from efficient KV cache handling in `init_states` and `extend_step`. Additionally, for use cases like local-global attention in LLMs, it is recommended to use sliding window masks for even the global attention as well. For example, if you want to train an LLM with a context length of 8k, you can set the sliding window size to 8k during training. This enables functionally infinite decoding during inference. Accuracy wouldn't be good tho. Note: * query_positions in QKVLinear.forward() was introduced by apple#914. Now it returns to the caller. This PR actually moves from downstream speech/streaming/sliding_window_attention.py * transpose commit 272a4d2 Author: Chang Lan <[email protected]> Date: Fri Feb 14 10:48:21 2025 -0800 Add v5e-8 (apple#994) commit debb46a Author: Mark Lee <[email protected]> Date: Thu Feb 13 22:27:28 2025 -0800 Decouples jobsets from replicated jobs. (apple#991) * Decouples jobsets from replicated jobs. * Address comments. commit 8f2b99d Author: Maggie Zhang <[email protected]> Date: Thu Feb 13 20:49:48 2025 -0800 Add Goodput documentation (apple#989) * Temporarily change checkpointing to every 5 steps * revert local changes * Add example command for goodput usage commit 31e8da0 Author: Alexander Pivovarov <[email protected]> Date: Thu Feb 13 15:44:51 2025 -0800 Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987) commit 7f2dd9e Author: Apoorv Gupta <[email protected]> Date: Thu Feb 13 14:36:11 2025 -0800 Flash Attention for Neuron (apple#939) commit 6ca4f56 Author: Philipp Dufter <[email protected]> Date: Thu Feb 13 23:09:28 2025 +0100 pass on log_warning in input_tf_data.skip_on_error (apple#990) * make log_warnings customizable in tfds skip error * address comments commit 1a8a0eb Author: Hanzhi Zhou <[email protected]> Date: Thu Feb 13 13:32:34 2025 -0800 Integrate Orbax's emergency checkpoint. (apple#820) * Integrate Orbax emergency checkpoint * Address comments * comment * Address comments * Upgrade orbax * Improve comments * Improve comments * Update for new orbax versions * Better timer * Address comments * Add step test * Fix * Add comment commit 42fd715 Author: Apoorv Gupta <[email protected]> Date: Thu Feb 13 09:13:30 2025 -0800 TRN2 Meshes and Configurations (apple#916) * TRN2 Meshes and Configurations * Add get_recursive and set_recursive to ConfigBase. * Use loops inside get/set_recursively + address comments * Update partition spec * Use get_recursively inside set * Move trn2 configs to a helper function. + Fix modifier tests * TRN2 partitionspec supports DP over FSDP and TP * Use for loop in get_recursively * Update Golden Configs commit d47d5ce Author: Haoshuo Huang <[email protected]> Date: Tue Feb 11 18:13:13 2025 -0800 Add support to slice dataset based on proportions. (apple#982) commit ed8f382 Author: Mark Lee <[email protected]> Date: Tue Feb 11 13:22:44 2025 -0800 Allow metrics layers to have state. (apple#978) * Allow metrics layers to have state. * Move BaseLossMetrics to a new file. commit b130416 Author: Chang Lan <[email protected]> Date: Tue Feb 11 00:01:28 2025 -0800 Transpose kv cache for better decode performance (apple#979) commit 48bf488 Author: Haoshuo Huang <[email protected]> Date: Mon Feb 10 22:25:18 2025 -0800 Add support for grain.IterDataset in sampling (apple#980) commit d4b563c Author: Alexander Pivovarov <[email protected]> Date: Mon Feb 10 15:36:22 2025 -0800 Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973) commit 0666d80 Author: Alexander Pivovarov <[email protected]> Date: Mon Feb 10 15:35:23 2025 -0800 Fix membership checks in tool_use_execution.py (apple#974) commit 2f4763c Author: Alexander Pivovarov <[email protected]> Date: Mon Feb 10 15:31:59 2025 -0800 Remove redundant import logging (apple#975) commit 58dcf33 Author: Hanzhi Zhou <[email protected]> Date: Mon Feb 10 13:41:33 2025 -0800 Enable cudnn dropout (apple#913) commit ae855ed Author: Mark Lee <[email protected]> Date: Mon Feb 10 12:43:50 2025 -0800 Ensures that cache_dtype is respected. (apple#977) commit cfef38b Author: Daniel Swann <[email protected]> Date: Mon Feb 10 10:56:10 2025 -0800 :sparkles: Add cache for CloudBuild API location queries (apple#967) commit 8fd9137 Author: Wei Liu <[email protected]> Date: Sun Feb 9 15:33:53 2025 -0800 Add segment_ids option in DiTAttentionLayer (apple#976) commit e55a404 Author: Chang Lan <[email protected]> Date: Sun Feb 9 04:38:49 2025 -0800 Use broadcasting trick for KV update (apple#972) * Use vmap and dynamic_update_slice for KV update * Broadcasting trick * Simplify the impl per @markblee's suggestion * comments commit b955187 Author: Dongseong Hwang <[email protected]> Date: Fri Feb 7 14:12:48 2025 -0800 Don't keep initial key/value inputs in the KV cache. (apple#968) The current code is weird. It stores the input key/value in the KV cache, but this doesn’t make sense in either init_states or prefill: * init_states: This is not prefill, so key/value should not be stored in the KV cache. * prefill: The extend_step() function overrides this part anyway. Thus, this PR removes this unnecessary and confusing logic. The logic was introduced in apple#860 commit c3d656d Author: zhengdong-zhang <[email protected]> Date: Fri Feb 7 10:18:42 2025 -0800 Refactorization. (apple#963) commit 1c883d8 Author: Zhao Xu <[email protected]> Date: Fri Feb 7 10:02:56 2025 -0800 Support system role when calling the Gemini API. (apple#971) commit ceab4f4 Author: Haoshuo Huang <[email protected]> Date: Thu Feb 6 20:41:07 2025 -0800 Making shared_memory configurable (apple#969) * Making shared_memory configurable * fix eol space commit 323faa3 Author: Meng (Ethan) Li <[email protected]> Date: Thu Feb 6 12:11:28 2025 -0800 Use env id for gcp settings (apple#957) * Use env_id to replace zone as gcp_settings key to support multiple env under the same zone * fall back to zone * address comments * Suppport project in the label filter; always get zone from gcp_setting value instead of return it directly commit 2ec3a02 Author: Chang Lan <[email protected]> Date: Wed Feb 5 22:25:58 2025 -0800 Fix incorrect number of formatting arguments (apple#966) commit d131d3b Author: Nan Du <[email protected]> Date: Mon Feb 3 11:44:00 2025 -0800 Reduce the verbosity of variable norm summaries (apple#965) commit c1c6e29 Author: Kelvin Zou <[email protected]> Date: Fri Jan 31 22:24:39 2025 -0800 Sliding window support for GPU flash attention (apple#962) * snapshot * snapshot * snapshot * remove unexpected change * adding shape commenbt * fix pylint * snapshot commit 0936a17 Author: Mark Lee <[email protected]> Date: Fri Jan 31 13:59:12 2025 -0800 Supports loss_weights and live_targets in metrics. (apple#960) * Supports loss_weights, live_targets, and module sharing in metrics. * Addresses comments. * Explicitly test flatten_metrics=True. commit 7a40f91 Author: Dipannita Shaw <[email protected]> Date: Fri Jan 31 11:45:33 2025 -0800 Add Goodput & Badput recording and monitoring support. (apple#783) * Code clean up * Add more testing * Fix docstrings * Remove recorder calls from trainer for now * Code cleanup gcp/measurement.py Co-authored-by: Ruoming Pang <[email protected]> * Code cleanup common/measurement.py Co-authored-by: Ruoming Pang <[email protected]> * Fix pre commit errors * Adding more tests * Further clean up * Fix a test error --------- Co-authored-by: Ruoming Pang <[email protected]> commit 031a7f3 Author: Mark Lee <[email protected]> Date: Thu Jan 30 20:19:12 2025 -0800 Skipping empty grain batches during unbatch. (apple#961) * Skipping empty grain batches during unbatch. * Use a loop instead of recursion. commit 795da33 Author: Hanzhi Zhou <[email protected]> Date: Thu Jan 30 07:17:16 2025 -0800 Optimizer offloading through weight-only offload (apple#867) * Optimizer offloading * Style fix * Type fix commit b1a1a5a Author: Haoshuo Huang <[email protected]> Date: Wed Jan 29 21:44:15 2025 -0800 Improve gcsfuse io (apple#959) commit d76ef6f Author: Hanzhi Zhou <[email protected]> Date: Wed Jan 29 15:10:13 2025 -0800 SplashAttention performance tuning for v6e (apple#958) * SplashAttention tuning for v6e * Add import to fix pytype errors commit 2d002e3 Author: Hanzhi Zhou <[email protected]> Date: Wed Jan 29 12:07:56 2025 -0800 Use InputDispatcher for fuji models (apple#956) * Use dispatcher * Update golden configs * Remove logical feed indices commit fad264b Author: Mark Lee <[email protected]> Date: Tue Jan 28 10:41:54 2025 -0800 Explicitly pass module outputs to metrics. (apple#953) * Explicitly pass module outputs to metrics. * Support and add checks for module/state updates. * Only flatten summaries. commit 59508e3 Author: Hanzhi Zhou <[email protected]> Date: Tue Jan 28 10:34:52 2025 -0800 Add v6e PCIe overload workaround flag (apple#955) commit 028ecfd Author: Haoshuo Huang <[email protected]> Date: Mon Jan 27 20:54:28 2025 -0800 Fix GCSFUSE flags by setting resource limit. (apple#954) commit 3e2c6dd Author: Matthew Hopkins <[email protected]> Date: Mon Jan 27 14:56:42 2025 -0800 update jax to 0.4.37 (apple#948) update BlockSpec usage in tpu_attention use TYPE_CHECKING for BuildDatasetFn in input_fake add todo for BuildDatasetFn commit b125f00 Author: Hanzhi Zhou <[email protected]> Date: Mon Jan 27 11:29:23 2025 -0800 Add v6e special meshes (apple#952) * Add v6e special mesh * Add v6e special mesh * Fix * Fix commit a854738 Author: Firenze11 <[email protected]> Date: Mon Jan 27 09:17:46 2025 -0800 Allow external positions to be inputed in RoPE embedding layer (apple#926) * Allow external positions to be inputed in RoPE embedding layer Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`. * Update attention_test.py * Update dit.py * Update attention.py * Update attention_test.py * Update attention.py * Update dit.py * Update axlearn/common/attention.py Co-authored-by: Mark Lee <[email protected]> * respond to comments. Co-authored-by: Ruoming Pang <[email protected]> * Update attention.py * Update attention.py * Update attention.py --------- Co-authored-by: Mark Lee <[email protected]> Co-authored-by: Ruoming Pang <[email protected]> commit 999401a Author: qdavid1 <[email protected]> Date: Mon Jan 27 09:11:17 2025 -0800 Update LoraFusedQKVLinear (apple#949) commit 1c22688 Author: Mark Lee <[email protected]> Date: Sun Jan 26 04:51:02 2025 -0800 Workaround module outputs being dropped. (apple#951) commit 94c81cb Author: Meng (Ethan) Li <[email protected]> Date: Fri Jan 24 11:01:45 2025 -0800 Add link to github issue regarding kubernetes-32.0.0 (apple#947) commit a6e0f4a Author: Meng (Ethan) Li <[email protected]> Date: Fri Jan 24 08:40:25 2025 -0800 Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946) commit 076521a Author: Mark Lee <[email protected]> Date: Thu Jan 23 15:11:00 2025 -0800 Forward input keys to decoder. (apple#944) commit 30284c8 Author: Hanzhi Zhou <[email protected]> Date: Thu Jan 23 10:33:54 2025 -0800 Legacy flash remat fix (apple#943) * Fix the same problem for legacy tpu attn * Fix commit 6a9f980 Author: Mark Lee <[email protected]> Date: Thu Jan 23 09:20:46 2025 -0800 Adds mesh rule for a3-megagpu-8g. (apple#936) commit ac7a3ed Author: Dongseong Hwang <[email protected]> Date: Thu Jan 23 08:15:27 2025 -0800 Enabled running Pallas Flash Attention on CPU. (apple#922) Pallas supports CPU simulation (`interpret=True`), so we can use the same TPU Pallas kernel on CPU — making code debugging easier. This change lets the following unittests run on CPU as if they were on TPU, enabling easier testing and debugging: - `axlearn/common/flash_attention/tpu_attention_test.py` Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU. - `axlearn/common/flash_attention/gpu_attention_test.py` Now CI covers those tests on CPU as well. In M3 Max MacBook Pro, test coverages and processing time are as follows, * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20) * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s commit 8ea85bd Author: Hanzhi Zhou <[email protected]> Date: Wed Jan 22 09:51:15 2025 -0800 Some fixes for flash remat (apple#942) commit 185b1b5 Author: Chang Lan <[email protected]> Date: Tue Jan 21 11:21:08 2025 -0800 Repeat KV heads in Flash Attention (apple#938) * Roll back '_repeat_kv_heads' change in Flash Attention Recent PR removed _repeat_kv_heads from Flash Attention for GQA optimization, in the hope to reduce HBM usage. However the actual HBM saving would be limited in the model-parallel setting, as the heads are already sharded across devices. It also introduces some limitation which breaks some of the existing sharding configurations. For example, let's say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads, we can set the model axis as 8 so that each device will have only one Q, K, V head; Without repeat_kv_heads, the max value of model axis is 4, and each device will have 2 Q heads as a result, increasing the actual HBM usage. * Repeat kv as necessary for sharding * Unit tests * Address comments. commit 4678740 Author: Chang Lan <[email protected]> Date: Mon Jan 20 20:36:44 2025 -0800 AOT compilation for v6e (apple#937) commit 357bef6 Author: Mark Lee <[email protected]> Date: Mon Jan 20 20:23:39 2025 -0800 Makes causal lm metrics configurable. (apple#934) * Makes causal lm metrics configurable. * Address review comments. * Make metrics required. * Update golden configs. * Removes PredictModel. commit 16ca0c2 Author: Mark Lee <[email protected]> Date: Sun Jan 19 14:19:20 2025 -0800 Supports flexible input partition specs. (apple#933) * Supports flexible input partition specs in causal lm. * Moves the input partitioning to Input. * Adds missing pytest marker. * Address review comments. * Rebase and update golden configs. * Fixes batch axis names and adds a test. commit 9b75ef1 Author: Mark Lee <[email protected]> Date: Sun Jan 19 07:43:19 2025 -0800 Avoid a top-level import of tokenizers. (apple#935) commit 9996f34 Author: sychen52 <[email protected]> Date: Sat Jan 18 09:44:04 2025 -0800 Add llama 3 tokenizer (apple#850) * Add llama 3 tokenizer add a new version called V3_TIKTOKEN. other edits based on suggestions. * Handle special tokens like other vocabularies. * use encode instead of encode_batch commit ad14de3 Author: Haoshuo Huang <[email protected]> Date: Fri Jan 17 14:19:24 2025 -0800 Add ReadOptions args to _make_autoregressive_inputs (apple#931) * Add ReadOptions args to _make_autoregressive_inputs * use read_options as args instead commit 4858070 Author: Sam Stoelinga <[email protected]> Date: Fri Jan 17 13:54:05 2025 -0800 improve GCS perf: Change resource limit to request (apple#851) commit b0ee05e Author: Bailin <[email protected]> Date: Fri Jan 17 22:53:00 2025 +0800 Add Mamab2 and its Jamba variant (apple#839) * add mamab2 * merge * unify init and prefill * adapt final changes --------- Co-authored-by: bailin_wang <[email protected]> commit 1e25e4a Author: Hanzhi Zhou <[email protected]> Date: Thu Jan 16 11:25:24 2025 -0800 Cache AoT compilation result (apple#927) * Cache AoT compilation result * Fix comments * Fix * Fix * Fix * Fix
commit 336c75d Author: Mark Lee <[email protected]> Date: Mon Mar 3 09:04:07 2025 -0800 Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029) * Allows specifying PartitionSpec to host_to_global_device_array. * Generalizes to arbitrary uniform partitioning. * Addresses comments and adds mixed shape test. commit 0881412 Author: Dongseong Hwang <[email protected]> Date: Sat Mar 1 15:41:38 2025 -0800 Refactor Mask in Attention (apple#1028) Currently, the attention code is **hardcoded** to handle either `causal_mask` or an arbitrary `mask_fn`. To support **sliding window masks**, we previously used a **hack** by injecting the `_sliding_window_size` attribute into functions. This refactor **makes the masking logic more flexible** by allowing arbitrary `MaskFnAttentionBias`. - If downstream requires a **new mask pattern**, they can simply: 1. Implement a **subclass of `MaskFnAttentionBias`**. 2. Set `attention.mask` accordingly. commit f67d3f9 Author: Dongseong Hwang <[email protected]> Date: Fri Feb 28 08:53:00 2025 -0800 Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026) Currently, Flash Attention infers decoding implicitly based on circumstantial evidence. This PR makes the check explicit. commit f8d2c66 Author: qdavid1 <[email protected]> Date: Thu Feb 27 15:26:18 2025 -0800 External KV input for _update_layer_kwargs (apple#1025) commit a3bf5e2 Author: Hanzhi Zhou <[email protected]> Date: Wed Feb 26 17:23:40 2025 -0800 Minor changes to Checkpointer (apple#1024) commit 55e1841 Author: Wentao Wu <[email protected]> Date: Wed Feb 26 15:45:51 2025 -0800 Add an option to break ties for top_k_logits when k = 1 (apple#1022) * Add an option to support stable top_k = 1. * address comments * address comments * address comments * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * Update axlearn/common/logit_modifiers.py Co-authored-by: Mark Lee <[email protected]> * address comments --------- Co-authored-by: Mark Lee <[email protected]> commit fbca3fc Author: Meng (Ethan) Li <[email protected]> Date: Wed Feb 26 14:05:25 2025 -0800 Add priority_class as a launch flag (apple#1020) commit b26bd74 Author: Meng (Ethan) Li <[email protected]> Date: Wed Feb 26 14:04:47 2025 -0800 Fix TypeError in calcualte_goodput.py (apple#1023) commit f8191e1 Author: Dongseong Hwang <[email protected]> Date: Wed Feb 26 11:03:44 2025 -0800 Emulate flash attentnion unittests on CPU. (apple#1021) utils.py codebase is not well covered by CI because it branches different backend. This PR introduces new CPU test, utils_test.py. This test is expected to run on CPU and is designed to validate GPU/TPU code from a CPU environment by fake mesh. It allows quick verification in CI and local environments to ensure that code changes do not break GPU/TPU Flash Attention. commit daec8c5 Author: Chang Liu <[email protected]> Date: Tue Feb 25 12:38:43 2025 -0800 Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019) Co-authored-by: Chang Liu <[email protected]> commit ac642ea Author: Dongseong Hwang <[email protected]> Date: Tue Feb 25 12:02:57 2025 -0800 Fix crash in log-mel frontend when waveform samples are integers. (apple#1017) After updating JAX, this existing hidden bug started causing CI failures. When the sample dtype is int32 (which is valid), `jnp.finfo` returns None, even though `jnp.iinfo` is available. The previous JAX version seemed to handle this case more forgivingly. ``` ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny)) ``` commit 7c64b55 Author: Meng (Ethan) Li <[email protected]> Date: Tue Feb 25 10:52:50 2025 -0800 Add LoadBalancer to GKE replicatedJob (apple#1015) Co-authored-by: Liang (SPG) He <[email protected]> commit 8e8a41b Author: Chang Lan <[email protected]> Date: Tue Feb 25 10:26:59 2025 -0800 Expose jax.lax.scan's unroll option to Repeat layer (apple#1016) * Expose jax.lax.scan's unroll option to Repeat layer. * Defaults to None to avoid golden config changes commit 682bce6 Author: Dongseong Hwang <[email protected]> Date: Tue Feb 25 10:09:41 2025 -0800 Handle None bias in BiasAndResidual (apple#1018) commit f053318 Author: Ruoming Pang <[email protected]> Date: Mon Feb 24 11:12:23 2025 -0500 Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013) * Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate(). * Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate(). commit a93cd1b Author: Luzy <[email protected]> Date: Sat Feb 22 19:55:01 2025 -0500 fix dtype in frontend pre emphasis (apple#1014) commit c1fe2e9 Author: Maggie Zhang <[email protected]> Date: Fri Feb 21 19:07:39 2025 -0800 GoodPut minor fix: only process 0 should start goodput uploader (apple#984) * only process 0 will start goodput uploader * Add unit test commit 4b1fbf0 Author: Chang Lan <[email protected]> Date: Fri Feb 21 13:36:45 2025 -0800 Async context invocation for checkpointing (apple#1012) * Async context invocation supoprt for checkpointer * Add comment * Add comments commit d4cd158 Author: Ruoming Pang <[email protected]> Date: Fri Feb 21 14:22:24 2025 -0500 Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011) * Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor. * Fixes pytype. * Addresses review. commit 2ae6e66 Author: Meng (Ethan) Li <[email protected]> Date: Fri Feb 21 09:22:35 2025 -0800 Enable megascale abort on hang or error (apple#1010) * Enable megascale_error_reporter_abort on hang and error by default * Increase threshold to 10m commit ce4b2fb Author: Chunyang Wen <[email protected]> Date: Fri Feb 21 23:12:38 2025 +0800 Add GPU monitor (apple#1006) commit baf8ad7 Author: Dongseong Hwang <[email protected]> Date: Thu Feb 20 19:35:07 2025 -0800 Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009) commit cf41112 Author: Hanzhi Zhou <[email protected]> Date: Thu Feb 20 16:29:13 2025 -0800 Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008) * Revert "gRPC Checkpointer (apple#1005)" This reverts commit d27c562. * Keep some changes commit 454bdba Author: Matthew Hopkins <[email protected]> Date: Thu Feb 20 15:10:51 2025 -0800 upgrade jax 0.4.38 (apple#1007) commit d27c562 Author: Hanzhi Zhou <[email protected]> Date: Tue Feb 18 18:54:53 2025 -0800 gRPC Checkpointer (apple#1005) commit fb90620 Author: Ruoming Pang <[email protected]> Date: Tue Feb 18 21:08:38 2025 -0500 Makes file_system.glob support multiple patterns. (apple#1003) * Makes file_system.glob support multiple patterns. * Makes file_system.glob support multiple patterns. * Makes file_system.glob support multiple patterns. * Makes file_system.glob support multiple patterns. commit 334f421 Author: Mark Lee <[email protected]> Date: Tue Feb 18 17:03:39 2025 -0800 Reverts sliding window attention changes. (apple#1004) * Revert "Fix flash decoding in GPU. (apple#999)" This reverts commit fdadfd8. * Revert "Supports TPU context parallel training (apple#981)" This reverts commit e151d69. * Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)" This reverts commit 67645d0. * Retain model/decoder asr changes. commit 3dacc6b Author: Chang Lan <[email protected]> Date: Mon Feb 17 19:45:00 2025 -0800 Refactor aot_compilation for reuse (apple#1000) commit c44fe18 Author: Ruoming Pang <[email protected]> Date: Mon Feb 17 21:38:17 2025 -0500 Makes checkpointer_test.py use file_system. (apple#1001) commit fdadfd8 Author: Dongseong Hwang <[email protected]> Date: Mon Feb 17 18:23:21 2025 -0800 Fix flash decoding in GPU. (apple#999) target_positions used to be time_step, but after PR apple#995, it now represents the actual target positions with shape [batch, step_len]. apple#995 Updating the GPU decoding code to align with this change. CI did not cover GPU unit tests. TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU commit 9e64388 Author: Ruoming Pang <[email protected]> Date: Mon Feb 17 16:40:03 2025 -0500 Makes axlearn/cloud/ use file_system. (apple#998) * Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency. * Adds testing for file_system.readfile. * Fixes pytype. * Makes axlearn/cloud use file_system instead of gfile. commit 5fba4ce Author: Chang Lan <[email protected]> Date: Mon Feb 17 09:44:10 2025 -0800 AOT compilation support for inference (apple#997) * Add optional `devices` init argument to InferenceRunner for passing fake devices during AOT compilation. * Add more v5e slice types. commit e151d69 Author: Hanzhi Zhou <[email protected]> Date: Sun Feb 16 13:01:15 2025 -0800 Supports TPU context parallel training (apple#981) Fix Fix tests commit 67645d0 Author: Dongseong Hwang <[email protected]> Date: Sat Feb 15 13:26:51 2025 -0800 Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995) * Revert "Transpose kv cache for better decode performance (apple#979)" This reverts commit b130416. * Update golden configs * Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for sliding window attention, the KV cache is kept for the full sequence length (`seq_len`) instead of the window length (`window_len`). For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache for the full 2M tokens. It then biases 1999k invalid KV tokens before calculating attention, resulting in a computational complexity of **O(2M²)** instead of the desired **O(1k²)**. This issue persists even when using flash attention. Flash attention uses the KV cache allocated in HBM as its input. While unnecessary blocks are discarded during computation, the KV cache still occupies HBM inefficiently for the full 2M tokens. To address this, when `MultiheadAttention` detects a sliding window mask, it stores the key-value (KV) cache in a ring buffer inside the input linear layer. As a result, downstream projects using `MultiheadAttention` automatically benefit from efficient KV cache handling in `init_states` and `extend_step`. Additionally, for use cases like local-global attention in LLMs, it is recommended to use sliding window masks for even the global attention as well. For example, if you want to train an LLM with a context length of 8k, you can set the sliding window size to 8k during training. This enables functionally infinite decoding during inference. Accuracy wouldn't be good tho. Note: * query_positions in QKVLinear.forward() was introduced by apple#914. Now it returns to the caller. This PR actually moves from downstream speech/streaming/sliding_window_attention.py * transpose commit 272a4d2 Author: Chang Lan <[email protected]> Date: Fri Feb 14 10:48:21 2025 -0800 Add v5e-8 (apple#994) commit debb46a Author: Mark Lee <[email protected]> Date: Thu Feb 13 22:27:28 2025 -0800 Decouples jobsets from replicated jobs. (apple#991) * Decouples jobsets from replicated jobs. * Address comments. commit 8f2b99d Author: Maggie Zhang <[email protected]> Date: Thu Feb 13 20:49:48 2025 -0800 Add Goodput documentation (apple#989) * Temporarily change checkpointing to every 5 steps * revert local changes * Add example command for goodput usage commit 31e8da0 Author: Alexander Pivovarov <[email protected]> Date: Thu Feb 13 15:44:51 2025 -0800 Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987) commit 7f2dd9e Author: Apoorv Gupta <[email protected]> Date: Thu Feb 13 14:36:11 2025 -0800 Flash Attention for Neuron (apple#939) commit 6ca4f56 Author: Philipp Dufter <[email protected]> Date: Thu Feb 13 23:09:28 2025 +0100 pass on log_warning in input_tf_data.skip_on_error (apple#990) * make log_warnings customizable in tfds skip error * address comments commit 1a8a0eb Author: Hanzhi Zhou <[email protected]> Date: Thu Feb 13 13:32:34 2025 -0800 Integrate Orbax's emergency checkpoint. (apple#820) * Integrate Orbax emergency checkpoint * Address comments * comment * Address comments * Upgrade orbax * Improve comments * Improve comments * Update for new orbax versions * Better timer * Address comments * Add step test * Fix * Add comment commit 42fd715 Author: Apoorv Gupta <[email protected]> Date: Thu Feb 13 09:13:30 2025 -0800 TRN2 Meshes and Configurations (apple#916) * TRN2 Meshes and Configurations * Add get_recursive and set_recursive to ConfigBase. * Use loops inside get/set_recursively + address comments * Update partition spec * Use get_recursively inside set * Move trn2 configs to a helper function. + Fix modifier tests * TRN2 partitionspec supports DP over FSDP and TP * Use for loop in get_recursively * Update Golden Configs commit d47d5ce Author: Haoshuo Huang <[email protected]> Date: Tue Feb 11 18:13:13 2025 -0800 Add support to slice dataset based on proportions. (apple#982) commit ed8f382 Author: Mark Lee <[email protected]> Date: Tue Feb 11 13:22:44 2025 -0800 Allow metrics layers to have state. (apple#978) * Allow metrics layers to have state. * Move BaseLossMetrics to a new file. commit b130416 Author: Chang Lan <[email protected]> Date: Tue Feb 11 00:01:28 2025 -0800 Transpose kv cache for better decode performance (apple#979) commit 48bf488 Author: Haoshuo Huang <[email protected]> Date: Mon Feb 10 22:25:18 2025 -0800 Add support for grain.IterDataset in sampling (apple#980) commit d4b563c Author: Alexander Pivovarov <[email protected]> Date: Mon Feb 10 15:36:22 2025 -0800 Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973) commit 0666d80 Author: Alexander Pivovarov <[email protected]> Date: Mon Feb 10 15:35:23 2025 -0800 Fix membership checks in tool_use_execution.py (apple#974) commit 2f4763c Author: Alexander Pivovarov <[email protected]> Date: Mon Feb 10 15:31:59 2025 -0800 Remove redundant import logging (apple#975) commit 58dcf33 Author: Hanzhi Zhou <[email protected]> Date: Mon Feb 10 13:41:33 2025 -0800 Enable cudnn dropout (apple#913) commit ae855ed Author: Mark Lee <[email protected]> Date: Mon Feb 10 12:43:50 2025 -0800 Ensures that cache_dtype is respected. (apple#977) commit cfef38b Author: Daniel Swann <[email protected]> Date: Mon Feb 10 10:56:10 2025 -0800 :sparkles: Add cache for CloudBuild API location queries (apple#967) commit 8fd9137 Author: Wei Liu <[email protected]> Date: Sun Feb 9 15:33:53 2025 -0800 Add segment_ids option in DiTAttentionLayer (apple#976) commit e55a404 Author: Chang Lan <[email protected]> Date: Sun Feb 9 04:38:49 2025 -0800 Use broadcasting trick for KV update (apple#972) * Use vmap and dynamic_update_slice for KV update * Broadcasting trick * Simplify the impl per @markblee's suggestion * comments commit b955187 Author: Dongseong Hwang <[email protected]> Date: Fri Feb 7 14:12:48 2025 -0800 Don't keep initial key/value inputs in the KV cache. (apple#968) The current code is weird. It stores the input key/value in the KV cache, but this doesn’t make sense in either init_states or prefill: * init_states: This is not prefill, so key/value should not be stored in the KV cache. * prefill: The extend_step() function overrides this part anyway. Thus, this PR removes this unnecessary and confusing logic. The logic was introduced in apple#860 commit c3d656d Author: zhengdong-zhang <[email protected]> Date: Fri Feb 7 10:18:42 2025 -0800 Refactorization. (apple#963) commit 1c883d8 Author: Zhao Xu <[email protected]> Date: Fri Feb 7 10:02:56 2025 -0800 Support system role when calling the Gemini API. (apple#971) commit ceab4f4 Author: Haoshuo Huang <[email protected]> Date: Thu Feb 6 20:41:07 2025 -0800 Making shared_memory configurable (apple#969) * Making shared_memory configurable * fix eol space commit 323faa3 Author: Meng (Ethan) Li <[email protected]> Date: Thu Feb 6 12:11:28 2025 -0800 Use env id for gcp settings (apple#957) * Use env_id to replace zone as gcp_settings key to support multiple env under the same zone * fall back to zone * address comments * Suppport project in the label filter; always get zone from gcp_setting value instead of return it directly commit 2ec3a02 Author: Chang Lan <[email protected]> Date: Wed Feb 5 22:25:58 2025 -0800 Fix incorrect number of formatting arguments (apple#966) commit d131d3b Author: Nan Du <[email protected]> Date: Mon Feb 3 11:44:00 2025 -0800 Reduce the verbosity of variable norm summaries (apple#965) commit c1c6e29 Author: Kelvin Zou <[email protected]> Date: Fri Jan 31 22:24:39 2025 -0800 Sliding window support for GPU flash attention (apple#962) * snapshot * snapshot * snapshot * remove unexpected change * adding shape commenbt * fix pylint * snapshot commit 0936a17 Author: Mark Lee <[email protected]> Date: Fri Jan 31 13:59:12 2025 -0800 Supports loss_weights and live_targets in metrics. (apple#960) * Supports loss_weights, live_targets, and module sharing in metrics. * Addresses comments. * Explicitly test flatten_metrics=True. commit 7a40f91 Author: Dipannita Shaw <[email protected]> Date: Fri Jan 31 11:45:33 2025 -0800 Add Goodput & Badput recording and monitoring support. (apple#783) * Code clean up * Add more testing * Fix docstrings * Remove recorder calls from trainer for now * Code cleanup gcp/measurement.py Co-authored-by: Ruoming Pang <[email protected]> * Code cleanup common/measurement.py Co-authored-by: Ruoming Pang <[email protected]> * Fix pre commit errors * Adding more tests * Further clean up * Fix a test error --------- Co-authored-by: Ruoming Pang <[email protected]> commit 031a7f3 Author: Mark Lee <[email protected]> Date: Thu Jan 30 20:19:12 2025 -0800 Skipping empty grain batches during unbatch. (apple#961) * Skipping empty grain batches during unbatch. * Use a loop instead of recursion. commit 795da33 Author: Hanzhi Zhou <[email protected]> Date: Thu Jan 30 07:17:16 2025 -0800 Optimizer offloading through weight-only offload (apple#867) * Optimizer offloading * Style fix * Type fix commit b1a1a5a Author: Haoshuo Huang <[email protected]> Date: Wed Jan 29 21:44:15 2025 -0800 Improve gcsfuse io (apple#959) commit d76ef6f Author: Hanzhi Zhou <[email protected]> Date: Wed Jan 29 15:10:13 2025 -0800 SplashAttention performance tuning for v6e (apple#958) * SplashAttention tuning for v6e * Add import to fix pytype errors commit 2d002e3 Author: Hanzhi Zhou <[email protected]> Date: Wed Jan 29 12:07:56 2025 -0800 Use InputDispatcher for fuji models (apple#956) * Use dispatcher * Update golden configs * Remove logical feed indices commit fad264b Author: Mark Lee <[email protected]> Date: Tue Jan 28 10:41:54 2025 -0800 Explicitly pass module outputs to metrics. (apple#953) * Explicitly pass module outputs to metrics. * Support and add checks for module/state updates. * Only flatten summaries. commit 59508e3 Author: Hanzhi Zhou <[email protected]> Date: Tue Jan 28 10:34:52 2025 -0800 Add v6e PCIe overload workaround flag (apple#955) commit 028ecfd Author: Haoshuo Huang <[email protected]> Date: Mon Jan 27 20:54:28 2025 -0800 Fix GCSFUSE flags by setting resource limit. (apple#954) commit 3e2c6dd Author: Matthew Hopkins <[email protected]> Date: Mon Jan 27 14:56:42 2025 -0800 update jax to 0.4.37 (apple#948) update BlockSpec usage in tpu_attention use TYPE_CHECKING for BuildDatasetFn in input_fake add todo for BuildDatasetFn commit b125f00 Author: Hanzhi Zhou <[email protected]> Date: Mon Jan 27 11:29:23 2025 -0800 Add v6e special meshes (apple#952) * Add v6e special mesh * Add v6e special mesh * Fix * Fix commit a854738 Author: Firenze11 <[email protected]> Date: Mon Jan 27 09:17:46 2025 -0800 Allow external positions to be inputed in RoPE embedding layer (apple#926) * Allow external positions to be inputed in RoPE embedding layer Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`. * Update attention_test.py * Update dit.py * Update attention.py * Update attention_test.py * Update attention.py * Update dit.py * Update axlearn/common/attention.py Co-authored-by: Mark Lee <[email protected]> * respond to comments. Co-authored-by: Ruoming Pang <[email protected]> * Update attention.py * Update attention.py * Update attention.py --------- Co-authored-by: Mark Lee <[email protected]> Co-authored-by: Ruoming Pang <[email protected]> commit 999401a Author: qdavid1 <[email protected]> Date: Mon Jan 27 09:11:17 2025 -0800 Update LoraFusedQKVLinear (apple#949) commit 1c22688 Author: Mark Lee <[email protected]> Date: Sun Jan 26 04:51:02 2025 -0800 Workaround module outputs being dropped. (apple#951) commit 94c81cb Author: Meng (Ethan) Li <[email protected]> Date: Fri Jan 24 11:01:45 2025 -0800 Add link to github issue regarding kubernetes-32.0.0 (apple#947) commit a6e0f4a Author: Meng (Ethan) Li <[email protected]> Date: Fri Jan 24 08:40:25 2025 -0800 Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946) commit 076521a Author: Mark Lee <[email protected]> Date: Thu Jan 23 15:11:00 2025 -0800 Forward input keys to decoder. (apple#944) commit 30284c8 Author: Hanzhi Zhou <[email protected]> Date: Thu Jan 23 10:33:54 2025 -0800 Legacy flash remat fix (apple#943) * Fix the same problem for legacy tpu attn * Fix commit 6a9f980 Author: Mark Lee <[email protected]> Date: Thu Jan 23 09:20:46 2025 -0800 Adds mesh rule for a3-megagpu-8g. (apple#936) commit ac7a3ed Author: Dongseong Hwang <[email protected]> Date: Thu Jan 23 08:15:27 2025 -0800 Enabled running Pallas Flash Attention on CPU. (apple#922) Pallas supports CPU simulation (`interpret=True`), so we can use the same TPU Pallas kernel on CPU — making code debugging easier. This change lets the following unittests run on CPU as if they were on TPU, enabling easier testing and debugging: - `axlearn/common/flash_attention/tpu_attention_test.py` Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU. - `axlearn/common/flash_attention/gpu_attention_test.py` Now CI covers those tests on CPU as well. In M3 Max MacBook Pro, test coverages and processing time are as follows, * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20) * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s commit 8ea85bd Author: Hanzhi Zhou <[email protected]> Date: Wed Jan 22 09:51:15 2025 -0800 Some fixes for flash remat (apple#942) commit 185b1b5 Author: Chang Lan <[email protected]> Date: Tue Jan 21 11:21:08 2025 -0800 Repeat KV heads in Flash Attention (apple#938) * Roll back '_repeat_kv_heads' change in Flash Attention Recent PR removed _repeat_kv_heads from Flash Attention for GQA optimization, in the hope to reduce HBM usage. However the actual HBM saving would be limited in the model-parallel setting, as the heads are already sharded across devices. It also introduces some limitation which breaks some of the existing sharding configurations. For example, let's say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads, we can set the model axis as 8 so that each device will have only one Q, K, V head; Without repeat_kv_heads, the max value of model axis is 4, and each device will have 2 Q heads as a result, increasing the actual HBM usage. * Repeat kv as necessary for sharding * Unit tests * Address comments. commit 4678740 Author: Chang Lan <[email protected]> Date: Mon Jan 20 20:36:44 2025 -0800 AOT compilation for v6e (apple#937) commit 357bef6 Author: Mark Lee <[email protected]> Date: Mon Jan 20 20:23:39 2025 -0800 Makes causal lm metrics configurable. (apple#934) * Makes causal lm metrics configurable. * Address review comments. * Make metrics required. * Update golden configs. * Removes PredictModel. commit 16ca0c2 Author: Mark Lee <[email protected]> Date: Sun Jan 19 14:19:20 2025 -0800 Supports flexible input partition specs. (apple#933) * Supports flexible input partition specs in causal lm. * Moves the input partitioning to Input. * Adds missing pytest marker. * Address review comments. * Rebase and update golden configs. * Fixes batch axis names and adds a test. commit 9b75ef1 Author: Mark Lee <[email protected]> Date: Sun Jan 19 07:43:19 2025 -0800 Avoid a top-level import of tokenizers. (apple#935) commit 9996f34 Author: sychen52 <[email protected]> Date: Sat Jan 18 09:44:04 2025 -0800 Add llama 3 tokenizer (apple#850) * Add llama 3 tokenizer add a new version called V3_TIKTOKEN. other edits based on suggestions. * Handle special tokens like other vocabularies. * use encode instead of encode_batch commit ad14de3 Author: Haoshuo Huang <[email protected]> Date: Fri Jan 17 14:19:24 2025 -0800 Add ReadOptions args to _make_autoregressive_inputs (apple#931) * Add ReadOptions args to _make_autoregressive_inputs * use read_options as args instead commit 4858070 Author: Sam Stoelinga <[email protected]> Date: Fri Jan 17 13:54:05 2025 -0800 improve GCS perf: Change resource limit to request (apple#851) commit b0ee05e Author: Bailin <[email protected]> Date: Fri Jan 17 22:53:00 2025 +0800 Add Mamab2 and its Jamba variant (apple#839) * add mamab2 * merge * unify init and prefill * adapt final changes --------- Co-authored-by: bailin_wang <[email protected]> commit 1e25e4a Author: Hanzhi Zhou <[email protected]> Date: Thu Jan 16 11:25:24 2025 -0800 Cache AoT compilation result (apple#927) * Cache AoT compilation result * Fix comments * Fix * Fix * Fix * Fix
No description provided.