
Integrate Orbax's emergency checkpoint. #820

Merged 13 commits into apple:main from in-mem-ckpt on Feb 13, 2025

Conversation

hanzhi713
Member

No description provided.

@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from b4a00eb to 0e74dc8 on November 12, 2024 19:49
@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from cf43485 to 56f51de on November 19, 2024 00:37
@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from d5a3e0f to bea8b71 on January 8, 2025 22:45
@hanzhi713 hanzhi713 force-pushed the in-mem-ckpt branch 2 times, most recently from 65f3d46 to c1a476d on January 30, 2025 21:16
@hanzhi713 hanzhi713 marked this pull request as ready for review January 30, 2025 23:13
@hanzhi713 hanzhi713 requested review from ruomingp, markblee and a team as code owners January 30, 2025 23:13
del os.environ["JAX_PLATFORMS"]


class OrbaxEmergencyCheckpointer(BaseCheckpointer):
Contributor

How does this overlap with `class OrbaxCheckpointer(BaseCheckpointer)`?

Member Author

The regular Orbax checkpointer saves one GCS checkpoint for n slices per save. This checkpointer saves n-1 checkpoints to a local path (usually a ramdisk) and one checkpoint to GCS. For example, with four slices, each save writes three local checkpoints and one GCS checkpoint.

Member Author

Unfortunately it's not possible to share code between the two implementations.

Contributor

Thanks for the clarification.

  • How should users use them? Should they use both of them or one of them? How should they pick?
  • Or can we replace OrbaxCheckpointer with this class?

Member Author

Added a comment to clarify this:

    This checkpointer is designed to improve the goodput of large multi-slice training jobs that
    use data-parallelism across slices. At least two data-parallel slices are required. For other
    use cases where this is not applicable or ultimate goodput is not required, please use
    `OrbaxCheckpointer`.
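
Concretely, the selection rule boils down to something like this (a sketch; the helper name is illustrative and not part of axlearn):

```python
def choose_checkpointer(num_data_parallel_slices: int) -> str:
    """Illustrative helper (not part of axlearn): pick a checkpointer class."""
    if num_data_parallel_slices >= 2:
        # n-1 local (ramdisk) checkpoints plus one persistent GCS checkpoint.
        return "OrbaxEmergencyCheckpointer"
    # A single slice has no peer slice to restore a local checkpoint from.
    return "OrbaxCheckpointer"

assert choose_checkpointer(4) == "OrbaxEmergencyCheckpointer"
assert choose_checkpointer(1) == "OrbaxCheckpointer"
```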

Contributor

IIUC, only the local checkpoints require multiple slices, since we will need to restore from another slice upon a slice restart. Could we disable local checkpoints when num_slices=1? This way we always use the emergency checkpointer consistently.

Member Author

Could be an idea.

Member Author

Still needs support from orbax though.

Contributor

I see. Right now it assumes local checkpoints are always used: https://github.com/google/orbax/blob/6e80ecc27581a413b1a481d4740e61df7316a4f4/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L695-L709.

Could you raise this request to the Orbax team and link to the issue?

Member Author

Done.

@ruomingp ruomingp self-assigned this Jan 31, 2025
@hanzhi713 hanzhi713 requested a review from ruomingp January 31, 2025 21:35
@ruomingp ruomingp removed their request for review February 1, 2025 02:51
Contributor

@ruomingp ruomingp left a comment

Thanks for the explanation of the constraints. I wonder what the long-term plan is.

Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?

Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?

@hanzhi713
Member Author

Is the emergency checkpointer a temporary solution that will eventually be dropped when the main Orbax checkpointer supports in-memory checkpoints?

I don't know if Google has such a plan. The Orbax in-memory checkpointer actually uses the regular Orbax checkpointer under the hood, which may be required by design or by the nature of the problem it solves.

Or will we keep maintaining two separate checkpointers, with potentially incompatible ckpt layouts?

Since the in-memory checkpointer uses the regular Orbax checkpointer under the hood, the tensor state in the persistent checkpoint (i.e., the one stored to GCS) can be loaded by OrbaxStateBuilder (see #866). The checkpoints are therefore compatible for eval and inference purposes. Only the training checkpoint is incompatible: OrbaxEmergencyCheckpointer's checkpoint cannot be loaded by OrbaxCheckpointer.

@hanzhi713
Member Author

I think in the long term it's probably possible to unify the checkpoint structure between the two checkpointers (regular and in-memory), but it's unknown whether we can unify the code paths.


Comment on lines 744 to 752
# Note that save() waits for prior serialization to finish.
self._non_tensor_manager.save(step=step, state=state)
self._get_tensor_manager(state_with_tensors).save(
step=step, args=ocp.args.PyTreeSave(item=state_with_tensors)
)
Contributor

How do we mark the completion of a checkpoint? It should happen only when both tensor and non-tensor states are saved. How is this ensured?

Please add a comment.

Member Author

There's no special marker for the completion of both; there are only markers for each of them individually. So, during restore, we look for both markers and only load a specific step when both exist.
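
Roughly, the restore-side step selection looks like this (a sketch, not the exact implementation):

```python
def latest_complete_step(tensor_steps, non_tensor_steps):
    """A step is restorable only if both managers have committed it."""
    common = set(tensor_steps) & set(non_tensor_steps)
    return max(common) if common else None

# E.g., step 300 only has a tensor marker, so the latest complete step is 200.
assert latest_complete_step([100, 200, 300], [100, 200]) == 200
```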

Contributor

Thanks. Can you add a comment here and point to where "we look for both markers and only load a specific step when both exist"?

Do we have testing for incomplete checkpoints?

Member Author

I can add a testcase.

Member Author

Done.

@hanzhi713
Member Author

@ruomingp I guess my question now is: what's the plan here? Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?

I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.

@ruomingp
Contributor

ruomingp commented Feb 2, 2025

@ruomingp I guess my question now is: what's the plan here? Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?

I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.

There's no best solution. I see three possibilities:
1. Wait until Orbax supports a Checkpointer that works for both single-slice and multi-slice settings with a unified persistent layout;
2. Merge this PR for testing, but be willing to discard checkpoints if we decide to change the persistent layout later;
3. Merge this PR and commit to unifying the checkpointers and building tools to convert checkpoints for users.

I do not think we want to maintain two Orbax checkpointers in the longer run, especially with incompatible layouts.

WDYT?


global_mesh=thread_resources.env.physical_mesh,
abstract_state=self._get_abstract_state(state_with_tensors),
options=oecp.CheckpointManagerOptions(
local=oecp.LocalCheckpointOptions(
Contributor

Shall we expose oecp.LocalCheckpointOptions to users as Config.local_checkpoint_options? Users could set it to None to disable local checkpoints.

We can provide a helper function for users to construct should_save_fn from their policy.
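
For example, something along these lines (a hypothetical sketch, assuming a `should_save_fn(step, last_saved_step)` callback as in orbax's CheckpointManagerOptions):

```python
from typing import Optional

def every_n_steps(n: int):
    """Hypothetical helper: adapt an every-n-steps policy to a
    should_save_fn(step, last_saved_step) -> bool callback."""
    def should_save_fn(step: int, last_saved_step: Optional[int] = None) -> bool:
        return step % n == 0
    return should_save_fn

should_save = every_n_steps(1000)
assert should_save(2000) and not should_save(1500)
```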

Member Author

It's already exposed, via local_keep_last_n and local_save_policy.

Comment on lines 770 to 772
# Find the intersection of the checkpoint steps managed by tensor and non-tensor
# manager, and then use the latest step in the intersection for restore. `all_steps`
# from tensor manager contains both local and persistent checkpoints.
Contributor

Consider refactoring this logic to a separate function so that it can be tested directly?

Member Author

Done and added a test.


@kelvin-zou
Contributor

@ruomingp I guess my question now is: what's the plan here? Should we wait for Orbax's support for non-tensor states and a unified checkpointer API?
I personally don't see in-mem ckpt as a life-changing feature, so waiting could be a viable option. Alternatively, we can proceed with this PR and make changes later.

There's no best solution. I see three possibilities: (1) wait until Orbax supports a Checkpointer that works for both single-slice and multi-slice settings with a unified persistent layout; (2) merge this PR for testing, but be willing to discard checkpoints if we decide to change the persistent layout later; (3) merge this PR and commit to unifying the checkpointers and building tools to convert checkpoints for users.

I do not think we want to maintain two Orbax checkpointers in the longer run, especially with incompatible layouts.

WDYT?

I think waiting may not be the best idea, given that Orbax development is often delayed; it has been over a year since Anthropic last raised this in-memory checkpointing feature request.

I think both options below work for me:

  1. Merge this PR for testing, but be willing to discard checkpoints if we decide to change the persistent layout later.
  2. Merge this PR and commit to unifying the checkpointers and building tools to convert checkpoints for users.

Maybe let's merge it for testing (we need quite some time to test thoroughly), then wait a quarter and see if Orbax can really make it happen and make it available to our users? If so, we can discard the tf iterator handling and use the upstream unified one. If not, we can commit to unification ourselves.

@hanzhi713
Member Author

@kelvin-zou SGTM.

@hanzhi713 hanzhi713 requested a review from ruomingp February 3, 2025 18:32
@hanzhi713
Member Author

@ruomingp Could you please approve this PR if it looks good?

Contributor

@ruomingp ruomingp left a comment

Approving to unblock experiments.

Comment on lines 478 to 480
"""Checkpointer implementation that uses Orbax emergency checkpoint.

## Summary:
Contributor

Suggested change (add an EXPERIMENTAL warning to the docstring):

      """Checkpointer implementation that uses Orbax emergency checkpoint.
    +
    + EXPERIMENTAL. Do not use for actual training runs since the checkpoint layout will
    + likely change in the future.
    +
      ## Summary:

Member Author

Done.

@hanzhi713 hanzhi713 requested a review from ruomingp February 13, 2025 20:52
@hanzhi713 hanzhi713 added this pull request to the merge queue Feb 13, 2025
Merged via the queue into apple:main with commit 1a8a0eb Feb 13, 2025
6 checks passed
@hanzhi713 hanzhi713 deleted the in-mem-ckpt branch February 13, 2025 22:04
persistent_directory=os.path.join(cfg.dir, self._TENSORS_PREFIX),
global_mesh=thread_resources.env.physical_mesh,
abstract_state=self._get_abstract_state(state_with_tensors),
options=oecp.CheckpointManagerOptions(
Contributor

local_state_handler is missing here.
CheckpointManager.__init__ does not have a default value for local_state_handler:

  def __init__(
      self,
      local_directory: epath.PathLike,
      persistent_directory: epath.PathLike,
      global_mesh: jax.sharding.Mesh,
      abstract_state: PyTree,  # a single PyTree describing the state structure
      # TODO: b/330585086 - Support arbitrary items beyond state. We will have
      # to evaluate whether arbitrary items can be a good fit for local
      # checkpointing, given restore+broadcast requirements.
      local_state_handler: CheckpointHandler,
      *,
      options: Optional[CheckpointManagerOptions] = None,
      metadata: Optional[dict[str, Any]] = None,
      logger: Optional[abstract_logger.AbstractLogger] = None,
  ):

pre-commit and pytype fail too.
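
For reference, a minimal sketch of supplying the missing argument on that orbax version (assuming `ocp.PyTreeCheckpointHandler` is an acceptable local state handler):

```python
import orbax.checkpoint as ocp

# Sketch: explicitly supply the (then-required) local_state_handler argument.
local_state_handler = ocp.PyTreeCheckpointHandler()
```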

Contributor

OK, upgrading to orbax-checkpoint 0.11.1 solved the issue.

rahul003 added a commit to rahul003/axlearn that referenced this pull request Mar 4, 2025
commit 336c75d
Author: Mark Lee <[email protected]>
Date:   Mon Mar 3 09:04:07 2025 -0800

    Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029)

    * Allows specifying PartitionSpec to host_to_global_device_array.

    * Generalizes to arbitrary uniform partitioning.

    * Addresses comments and adds mixed shape test.

commit 0881412
Author: Dongseong Hwang <[email protected]>
Date:   Sat Mar 1 15:41:38 2025 -0800

    Refactor Mask in Attention (apple#1028)

    Currently, the attention code is **hardcoded** to handle either `causal_mask`
    or an arbitrary `mask_fn`.

    To support **sliding window masks**, we previously used a **hack** by injecting
    the `_sliding_window_size` attribute into functions.

    This refactor **makes the masking logic more flexible** by allowing arbitrary
    `MaskFnAttentionBias`.
    - If downstream requires a **new mask pattern**, they can simply:
      1. Implement a **subclass of `MaskFnAttentionBias`**.
      2. Set `attention.mask` accordingly.

commit f67d3f9
Author: Dongseong Hwang <[email protected]>
Date:   Fri Feb 28 08:53:00 2025 -0800

    Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026)

    Currently, Flash Attention infers decoding implicitly based on circumstantial
    evidence. This PR makes the check explicit.

commit f8d2c66
Author: qdavid1 <[email protected]>
Date:   Thu Feb 27 15:26:18 2025 -0800

    External KV input for _update_layer_kwargs (apple#1025)

commit a3bf5e2
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Feb 26 17:23:40 2025 -0800

    Minor changes to Checkpointer (apple#1024)

commit 55e1841
Author: Wentao Wu <[email protected]>
Date:   Wed Feb 26 15:45:51 2025 -0800

    Add an option to break ties for top_k_logits when k = 1 (apple#1022)

    * Add an option to support stable top_k = 1.

    * address comments

    * address comments

    * address comments

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * address comments

    ---------

    Co-authored-by: Mark Lee <[email protected]>

commit fbca3fc
Author: Meng (Ethan) Li <[email protected]>
Date:   Wed Feb 26 14:05:25 2025 -0800

    Add priority_class as a launch flag (apple#1020)

commit b26bd74
Author: Meng (Ethan) Li <[email protected]>
Date:   Wed Feb 26 14:04:47 2025 -0800

    Fix TypeError in calcualte_goodput.py (apple#1023)

commit f8191e1
Author: Dongseong Hwang <[email protected]>
Date:   Wed Feb 26 11:03:44 2025 -0800

    Emulate flash attentnion unittests on CPU. (apple#1021)

    utils.py codebase is not well covered by CI because it branches different
    backend.

    This PR introduces new CPU test, utils_test.py.
    This test is expected to run on CPU and is designed to validate GPU/TPU code
    from a CPU environment by fake mesh.
    It allows quick verification in CI and local environments to ensure that code
    changes do not break GPU/TPU Flash Attention.

commit daec8c5
Author: Chang Liu <[email protected]>
Date:   Tue Feb 25 12:38:43 2025 -0800

    Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019)

    Co-authored-by: Chang Liu <[email protected]>

commit ac642ea
Author: Dongseong Hwang <[email protected]>
Date:   Tue Feb 25 12:02:57 2025 -0800

    Fix crash in log-mel frontend when waveform samples are integers. (apple#1017)

    After updating JAX, this existing hidden bug started causing CI failures.
    When the sample dtype is int32 (which is valid), `jnp.finfo` returns None,
    even though `jnp.iinfo` is available.
    The previous JAX version seemed to handle this case more forgivingly.

    ```
    ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram
        return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny))
    ```

commit 7c64b55
Author: Meng (Ethan) Li <[email protected]>
Date:   Tue Feb 25 10:52:50 2025 -0800

    Add LoadBalancer to GKE replicatedJob (apple#1015)

    Co-authored-by: Liang (SPG) He <[email protected]>

commit 8e8a41b
Author: Chang Lan <[email protected]>
Date:   Tue Feb 25 10:26:59 2025 -0800

    Expose jax.lax.scan's unroll option to Repeat layer (apple#1016)

    * Expose jax.lax.scan's unroll option to Repeat layer.

    * Defaults to None to avoid golden config changes

commit 682bce6
Author: Dongseong Hwang <[email protected]>
Date:   Tue Feb 25 10:09:41 2025 -0800

    Handle None bias in BiasAndResidual (apple#1018)

commit f053318
Author: Ruoming Pang <[email protected]>
Date:   Mon Feb 24 11:12:23 2025 -0500

    Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013)

    * Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate().

    * Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate().

commit a93cd1b
Author: Luzy <[email protected]>
Date:   Sat Feb 22 19:55:01 2025 -0500

    fix dtype in frontend pre emphasis (apple#1014)

commit c1fe2e9
Author: Maggie Zhang <[email protected]>
Date:   Fri Feb 21 19:07:39 2025 -0800

    GoodPut minor fix: only process 0 should start goodput uploader (apple#984)

    * only process 0 will start goodput uploader

    * Add unit test

commit 4b1fbf0
Author: Chang Lan <[email protected]>
Date:   Fri Feb 21 13:36:45 2025 -0800

    Async context invocation for checkpointing (apple#1012)

    * Async context invocation supoprt for checkpointer

    * Add comment

    * Add comments

commit d4cd158
Author: Ruoming Pang <[email protected]>
Date:   Fri Feb 21 14:22:24 2025 -0500

    Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011)

    * Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase.

    This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor.

    * Fixes pytype.

    * Addresses review.

commit 2ae6e66
Author: Meng (Ethan) Li <[email protected]>
Date:   Fri Feb 21 09:22:35 2025 -0800

    Enable megascale abort on hang or error (apple#1010)

    * Enable megascale_error_reporter_abort on hang and error by default

    * Increase threshold to 10m

commit ce4b2fb
Author: Chunyang Wen <[email protected]>
Date:   Fri Feb 21 23:12:38 2025 +0800

    Add GPU monitor (apple#1006)

commit baf8ad7
Author: Dongseong Hwang <[email protected]>
Date:   Thu Feb 20 19:35:07 2025 -0800

    Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009)

commit cf41112
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Feb 20 16:29:13 2025 -0800

    Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008)

    * Revert "gRPC Checkpointer (apple#1005)"

    This reverts commit d27c562.

    * Keep some changes

commit 454bdba
Author: Matthew Hopkins <[email protected]>
Date:   Thu Feb 20 15:10:51 2025 -0800

    upgrade jax 0.4.38 (apple#1007)

commit d27c562
Author: Hanzhi Zhou <[email protected]>
Date:   Tue Feb 18 18:54:53 2025 -0800

    gRPC Checkpointer (apple#1005)

commit fb90620
Author: Ruoming Pang <[email protected]>
Date:   Tue Feb 18 21:08:38 2025 -0500

    Makes file_system.glob support multiple patterns. (apple#1003)

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

commit 334f421
Author: Mark Lee <[email protected]>
Date:   Tue Feb 18 17:03:39 2025 -0800

    Reverts sliding window attention changes. (apple#1004)

    * Revert "Fix flash decoding in GPU. (apple#999)"

    This reverts commit fdadfd8.

    * Revert "Supports TPU context parallel training (apple#981)"

    This reverts commit e151d69.

    * Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)"

    This reverts commit 67645d0.

    * Retain model/decoder asr changes.

commit 3dacc6b
Author: Chang Lan <[email protected]>
Date:   Mon Feb 17 19:45:00 2025 -0800

    Refactor aot_compilation for reuse (apple#1000)

commit c44fe18
Author: Ruoming Pang <[email protected]>
Date:   Mon Feb 17 21:38:17 2025 -0500

    Makes checkpointer_test.py use file_system. (apple#1001)

commit fdadfd8
Author: Dongseong Hwang <[email protected]>
Date:   Mon Feb 17 18:23:21 2025 -0800

    Fix flash decoding in GPU. (apple#999)

    target_positions used to be time_step, but after PR apple#995, it now represents the
    actual target positions with shape [batch, step_len].
    apple#995

    Updating the GPU decoding code to align with this change.

    CI did not cover GPU unit tests.

    TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU

commit 9e64388
Author: Ruoming Pang <[email protected]>
Date:   Mon Feb 17 16:40:03 2025 -0500

    Makes axlearn/cloud/ use file_system. (apple#998)

    * Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency.

    * Adds testing for file_system.readfile.

    * Fixes pytype.

    * Makes axlearn/cloud use file_system instead of gfile.

commit 5fba4ce
Author: Chang Lan <[email protected]>
Date:   Mon Feb 17 09:44:10 2025 -0800

    AOT compilation support for inference (apple#997)

    * Add optional `devices` init argument to InferenceRunner for passing
      fake devices during AOT compilation.

    * Add more v5e slice types.

commit e151d69
Author: Hanzhi Zhou <[email protected]>
Date:   Sun Feb 16 13:01:15 2025 -0800

    Supports TPU context parallel training (apple#981)

    Fix

    Fix tests

commit 67645d0
Author: Dongseong Hwang <[email protected]>
Date:   Sat Feb 15 13:26:51 2025 -0800

    Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)

    * Revert "Transpose kv cache for better decode performance (apple#979)"

    This reverts commit b130416.

    * Update golden configs

    * Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding.

    Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for
    sliding window attention, the KV cache is kept for the full sequence length
    (`seq_len`) instead of the window length (`window_len`).

    For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache
    for the full 2M tokens. It then biases 1999k invalid KV tokens before
    calculating attention, resulting in a computational complexity of **O(2M²)**
    instead of the desired **O(1k²)**.

    This issue persists even when using flash attention. Flash attention uses the
    KV cache allocated in HBM as its input. While unnecessary blocks are discarded
    during computation, the KV cache still occupies HBM inefficiently for the full
    2M tokens.

    To address this, when `MultiheadAttention` detects a sliding window mask, it
    stores the key-value (KV) cache in a ring buffer inside the input linear layer.
    As a result, downstream projects using `MultiheadAttention` automatically
    benefit from efficient KV cache handling in `init_states` and `extend_step`.

    Additionally, for use cases like local-global attention in LLMs, it is
    recommended to use sliding window masks for even the global attention as well.
    For example, if you want to train an LLM with a context length of 8k, you can
    set the sliding window size to 8k during training. This enables functionally
    infinite decoding during inference. Accuracy wouldn't be good tho.

    Note:
    * query_positions in QKVLinear.forward() was introduced by
      apple#914. Now it returns to the caller.

    This PR actually moves from downstream speech/streaming/sliding_window_attention.py

    * transpose

commit 272a4d2
Author: Chang Lan <[email protected]>
Date:   Fri Feb 14 10:48:21 2025 -0800

    Add v5e-8 (apple#994)

commit debb46a
Author: Mark Lee <[email protected]>
Date:   Thu Feb 13 22:27:28 2025 -0800

    Decouples jobsets from replicated jobs. (apple#991)

    * Decouples jobsets from replicated jobs.

    * Address comments.

commit 8f2b99d
Author: Maggie Zhang <[email protected]>
Date:   Thu Feb 13 20:49:48 2025 -0800

    Add Goodput documentation (apple#989)

    * Temporarily change checkpointing to every 5 steps

    * revert local changes

    * Add example command for goodput usage

commit 31e8da0
Author: Alexander Pivovarov <[email protected]>
Date:   Thu Feb 13 15:44:51 2025 -0800

    Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987)

commit 7f2dd9e
Author: Apoorv Gupta <[email protected]>
Date:   Thu Feb 13 14:36:11 2025 -0800

    Flash Attention for Neuron (apple#939)

commit 6ca4f56
Author: Philipp Dufter <[email protected]>
Date:   Thu Feb 13 23:09:28 2025 +0100

    pass on log_warning in input_tf_data.skip_on_error (apple#990)

    * make log_warnings customizable in tfds skip error

    * address comments

commit 1a8a0eb
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Feb 13 13:32:34 2025 -0800

    Integrate Orbax's emergency checkpoint. (apple#820)

    * Integrate Orbax emergency checkpoint

    * Address comments

    * comment

    * Address comments

    * Upgrade orbax

    * Improve comments

    * Improve comments

    * Update for new orbax versions

    * Better timer

    * Address comments

    * Add step test

    * Fix

    * Add comment

commit 42fd715
Author: Apoorv Gupta <[email protected]>
Date:   Thu Feb 13 09:13:30 2025 -0800

    TRN2 Meshes and Configurations (apple#916)

    * TRN2 Meshes and Configurations

    * Add get_recursive and set_recursive to ConfigBase.

    * Use loops inside get/set_recursively

    + address comments

    * Update partition spec

    * Use get_recursively inside set

    * Move trn2 configs to a helper function.

    + Fix modifier tests

    * TRN2 partitionspec supports DP over FSDP and TP

    * Use for loop in get_recursively

    * Update Golden Configs

commit d47d5ce
Author: Haoshuo Huang <[email protected]>
Date:   Tue Feb 11 18:13:13 2025 -0800

    Add support to slice dataset based on proportions. (apple#982)

commit ed8f382
Author: Mark Lee <[email protected]>
Date:   Tue Feb 11 13:22:44 2025 -0800

    Allow metrics layers to have state. (apple#978)

    * Allow metrics layers to have state.

    * Move BaseLossMetrics to a new file.

commit b130416
Author: Chang Lan <[email protected]>
Date:   Tue Feb 11 00:01:28 2025 -0800

    Transpose kv cache for better decode performance (apple#979)

commit 48bf488
Author: Haoshuo Huang <[email protected]>
Date:   Mon Feb 10 22:25:18 2025 -0800

    Add support for grain.IterDataset in sampling (apple#980)

commit d4b563c
Author: Alexander Pivovarov <[email protected]>
Date:   Mon Feb 10 15:36:22 2025 -0800

    Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973)

commit 0666d80
Author: Alexander Pivovarov <[email protected]>
Date:   Mon Feb 10 15:35:23 2025 -0800

    Fix membership checks in tool_use_execution.py (apple#974)

commit 2f4763c
Author: Alexander Pivovarov <[email protected]>
Date:   Mon Feb 10 15:31:59 2025 -0800

    Remove redundant import logging (apple#975)

commit 58dcf33
Author: Hanzhi Zhou <[email protected]>
Date:   Mon Feb 10 13:41:33 2025 -0800

    Enable cudnn dropout (apple#913)

commit ae855ed
Author: Mark Lee <[email protected]>
Date:   Mon Feb 10 12:43:50 2025 -0800

    Ensures that cache_dtype is respected. (apple#977)

commit cfef38b
Author: Daniel Swann <[email protected]>
Date:   Mon Feb 10 10:56:10 2025 -0800

    :sparkles: Add cache for CloudBuild API location queries (apple#967)

commit 8fd9137
Author: Wei Liu <[email protected]>
Date:   Sun Feb 9 15:33:53 2025 -0800

    Add segment_ids option in DiTAttentionLayer (apple#976)

commit e55a404
Author: Chang Lan <[email protected]>
Date:   Sun Feb 9 04:38:49 2025 -0800

    Use broadcasting trick for KV update (apple#972)

    * Use vmap and dynamic_update_slice for KV update

    * Broadcasting trick

    * Simplify the impl per @markblee's suggestion

    * comments

commit b955187
Author: Dongseong Hwang <[email protected]>
Date:   Fri Feb 7 14:12:48 2025 -0800

    Don't keep initial key/value inputs in the KV cache. (apple#968)

    The current code is weird. It stores the input key/value in the KV cache, but
    this doesn’t make sense in either init_states or prefill:
    * init_states: This is not prefill, so key/value should not be stored in the KV cache.
    * prefill: The extend_step() function overrides this part anyway.

    Thus, this PR removes this unnecessary and confusing logic.
    The logic was introduced in apple#860

commit c3d656d
Author: zhengdong-zhang <[email protected]>
Date:   Fri Feb 7 10:18:42 2025 -0800

    Refactorization. (apple#963)

commit 1c883d8
Author: Zhao Xu <[email protected]>
Date:   Fri Feb 7 10:02:56 2025 -0800

    Support system role when calling the Gemini API. (apple#971)

commit ceab4f4
Author: Haoshuo Huang <[email protected]>
Date:   Thu Feb 6 20:41:07 2025 -0800

    Making shared_memory configurable (apple#969)

    * Making shared_memory configurable

    * fix eol space

commit 323faa3
Author: Meng (Ethan) Li <[email protected]>
Date:   Thu Feb 6 12:11:28 2025 -0800

    Use env id for gcp settings (apple#957)

    * Use env_id to replace zone as gcp_settings key to support multiple env under the same zone

    * fall back to zone

    * address comments

    * Suppport project in the label filter; always get zone from gcp_setting value instead of return it directly

commit 2ec3a02
Author: Chang Lan <[email protected]>
Date:   Wed Feb 5 22:25:58 2025 -0800

    Fix incorrect number of formatting arguments (apple#966)

commit d131d3b
Author: Nan Du <[email protected]>
Date:   Mon Feb 3 11:44:00 2025 -0800

    Reduce the verbosity of variable norm summaries (apple#965)

commit c1c6e29
Author: Kelvin Zou <[email protected]>
Date:   Fri Jan 31 22:24:39 2025 -0800

    Sliding window support for GPU flash attention (apple#962)

    * snapshot

    * snapshot

    * snapshot

    * remove unexpected change

    * adding shape commenbt

    * fix pylint

    * snapshot

commit 0936a17
Author: Mark Lee <[email protected]>
Date:   Fri Jan 31 13:59:12 2025 -0800

    Supports loss_weights and live_targets in metrics. (apple#960)

    * Supports loss_weights, live_targets, and module sharing in metrics.

    * Addresses comments.

    * Explicitly test flatten_metrics=True.

commit 7a40f91
Author: Dipannita Shaw <[email protected]>
Date:   Fri Jan 31 11:45:33 2025 -0800

    Add Goodput & Badput recording and monitoring support. (apple#783)

    * Code clean up

    * Add more testing

    * Fix docstrings

    * Remove recorder calls from trainer for now

    * Code cleanup gcp/measurement.py

    Co-authored-by: Ruoming Pang <[email protected]>

    * Code cleanup  common/measurement.py

    Co-authored-by: Ruoming Pang <[email protected]>

    * Fix pre commit errors

    * Adding more tests

    * Further clean up

    * Fix a test error

    ---------

    Co-authored-by: Ruoming Pang <[email protected]>

commit 031a7f3
Author: Mark Lee <[email protected]>
Date:   Thu Jan 30 20:19:12 2025 -0800

    Skipping empty grain batches during unbatch. (apple#961)

    * Skipping empty grain batches during unbatch.

    * Use a loop instead of recursion.

commit 795da33
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Jan 30 07:17:16 2025 -0800

    Optimizer offloading through weight-only offload (apple#867)

    * Optimizer offloading

    * Style fix

    * Type fix

commit b1a1a5a
Author: Haoshuo Huang <[email protected]>
Date:   Wed Jan 29 21:44:15 2025 -0800

    Improve gcsfuse io (apple#959)

commit d76ef6f
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Jan 29 15:10:13 2025 -0800

    SplashAttention performance tuning for v6e (apple#958)

    * SplashAttention tuning for v6e

    * Add import to fix pytype errors

commit 2d002e3
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Jan 29 12:07:56 2025 -0800

    Use InputDispatcher for fuji models (apple#956)

    * Use dispatcher

    * Update golden configs

    * Remove logical feed indices

commit fad264b
Author: Mark Lee <[email protected]>
Date:   Tue Jan 28 10:41:54 2025 -0800

    Explicitly pass module outputs to metrics. (apple#953)

    * Explicitly pass module outputs to metrics.

    * Support and add checks for module/state updates.

    * Only flatten summaries.

commit 59508e3
Author: Hanzhi Zhou <[email protected]>
Date:   Tue Jan 28 10:34:52 2025 -0800

    Add v6e PCIe overload workaround flag (apple#955)

commit 028ecfd
Author: Haoshuo Huang <[email protected]>
Date:   Mon Jan 27 20:54:28 2025 -0800

    Fix GCSFUSE flags by setting resource limit. (apple#954)

commit 3e2c6dd
Author: Matthew Hopkins <[email protected]>
Date:   Mon Jan 27 14:56:42 2025 -0800

    update jax to 0.4.37 (apple#948)

    update BlockSpec usage in tpu_attention
    use TYPE_CHECKING for BuildDatasetFn in input_fake
    add todo for BuildDatasetFn

commit b125f00
Author: Hanzhi Zhou <[email protected]>
Date:   Mon Jan 27 11:29:23 2025 -0800

    Add v6e special meshes (apple#952)

    * Add v6e special mesh

    * Add v6e special mesh

    * Fix

    * Fix

commit a854738
Author: Firenze11 <[email protected]>
Date:   Mon Jan 27 09:17:46 2025 -0800

    Allow external positions to be inputed in RoPE embedding layer (apple#926)

    * Allow external positions to be inputed in RoPE embedding layer

    Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.

    * Update attention_test.py

    * Update dit.py

    * Update attention.py

    * Update attention_test.py

    * Update attention.py

    * Update dit.py

    * Update axlearn/common/attention.py

    Co-authored-by: Mark Lee <[email protected]>

    * respond to comments.

    Co-authored-by: Ruoming Pang <[email protected]>

    * Update attention.py

    * Update attention.py

    * Update attention.py

    ---------

    Co-authored-by: Mark Lee <[email protected]>
    Co-authored-by: Ruoming Pang <[email protected]>

commit 999401a
Author: qdavid1 <[email protected]>
Date:   Mon Jan 27 09:11:17 2025 -0800

    Update LoraFusedQKVLinear (apple#949)

commit 1c22688
Author: Mark Lee <[email protected]>
Date:   Sun Jan 26 04:51:02 2025 -0800

    Workaround module outputs being dropped. (apple#951)

commit 94c81cb
Author: Meng (Ethan) Li <[email protected]>
Date:   Fri Jan 24 11:01:45 2025 -0800

    Add link to github issue regarding kubernetes-32.0.0 (apple#947)

commit a6e0f4a
Author: Meng (Ethan) Li <[email protected]>
Date:   Fri Jan 24 08:40:25 2025 -0800

    Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946)

commit 076521a
Author: Mark Lee <[email protected]>
Date:   Thu Jan 23 15:11:00 2025 -0800

    Forward input keys to decoder. (apple#944)

commit 30284c8
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Jan 23 10:33:54 2025 -0800

    Legacy flash remat fix (apple#943)

    * Fix the same problem for legacy tpu attn

    * Fix

commit 6a9f980
Author: Mark Lee <[email protected]>
Date:   Thu Jan 23 09:20:46 2025 -0800

    Adds mesh rule for a3-megagpu-8g. (apple#936)

commit ac7a3ed
Author: Dongseong Hwang <[email protected]>
Date:   Thu Jan 23 08:15:27 2025 -0800

    Enabled running Pallas Flash Attention on CPU. (apple#922)

    Pallas supports CPU simulation (`interpret=True`), so we can use the same
    TPU Pallas kernel on CPU — making code debugging easier.

    This change lets the following unittests run on CPU as if they were on TPU,
    enabling easier testing and debugging:
    - `axlearn/common/flash_attention/tpu_attention_test.py`

    Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU.
    - `axlearn/common/flash_attention/gpu_attention_test.py`

    Now CI covers those tests on CPU as well.
    In M3 Max MacBook Pro, test coverages and processing time are as follows,
    * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
    * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s

commit 8ea85bd
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Jan 22 09:51:15 2025 -0800

    Some fixes for flash remat (apple#942)

commit 185b1b5
Author: Chang Lan <[email protected]>
Date:   Tue Jan 21 11:21:08 2025 -0800

    Repeat KV heads in Flash Attention (apple#938)

    * Roll back '_repeat_kv_heads' change in Flash Attention

    Recent PR removed _repeat_kv_heads from Flash Attention for GQA optimization,
    in the hope to reduce HBM usage. However the actual HBM saving would be limited
    in the model-parallel setting, as the heads are already sharded across devices.
    It also introduces some limitation which breaks some of the existing sharding
    configurations.

    For example, let's say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads,
    we can set the model axis as 8 so that each device will have only one Q, K, V head;
    Without repeat_kv_heads, the max value of model axis is 4, and each device will have
    2 Q heads as a result, increasing the actual HBM usage.

    * Repeat kv as necessary for sharding

    * Unit tests

    * Address comments.

commit 4678740
Author: Chang Lan <[email protected]>
Date:   Mon Jan 20 20:36:44 2025 -0800

    AOT compilation for v6e (apple#937)

commit 357bef6
Author: Mark Lee <[email protected]>
Date:   Mon Jan 20 20:23:39 2025 -0800

    Makes causal lm metrics configurable. (apple#934)

    * Makes causal lm metrics configurable.

    * Address review comments.

    * Make metrics required.

    * Update golden configs.

    * Removes PredictModel.

commit 16ca0c2
Author: Mark Lee <[email protected]>
Date:   Sun Jan 19 14:19:20 2025 -0800

    Supports flexible input partition specs. (apple#933)

    * Supports flexible input partition specs in causal lm.

    * Moves the input partitioning to Input.

    * Adds missing pytest marker.

    * Address review comments.

    * Rebase and update golden configs.

    * Fixes batch axis names and adds a test.

commit 9b75ef1
Author: Mark Lee <[email protected]>
Date:   Sun Jan 19 07:43:19 2025 -0800

    Avoid a top-level import of tokenizers. (apple#935)

commit 9996f34
Author: sychen52 <[email protected]>
Date:   Sat Jan 18 09:44:04 2025 -0800

    Add llama 3 tokenizer (apple#850)

    * Add llama 3 tokenizer

    add a new version called V3_TIKTOKEN.

    other edits based on suggestions.

    * Handle special tokens like other vocabularies.

    * use encode instead of encode_batch

commit ad14de3
Author: Haoshuo Huang <[email protected]>
Date:   Fri Jan 17 14:19:24 2025 -0800

    Add ReadOptions args to _make_autoregressive_inputs (apple#931)

    * Add ReadOptions args to _make_autoregressive_inputs

    * use read_options as args instead

commit 4858070
Author: Sam Stoelinga <[email protected]>
Date:   Fri Jan 17 13:54:05 2025 -0800

    improve GCS perf: Change resource limit to request (apple#851)

commit b0ee05e
Author: Bailin <[email protected]>
Date:   Fri Jan 17 22:53:00 2025 +0800

    Add Mamab2 and its Jamba variant (apple#839)

    * add mamab2

    * merge

    * unify init and prefill

    * adapt final changes

    ---------

    Co-authored-by: bailin_wang <[email protected]>

commit 1e25e4a
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Jan 16 11:25:24 2025 -0800

    Cache AoT compilation result (apple#927)

    * Cache AoT compilation result

    * Fix comments

    * Fix

    * Fix

    * Fix

    * Fix
rahul003 added a commit to rahul003/axlearn that referenced this pull request Mar 5, 2025
commit 336c75d
Author: Mark Lee <[email protected]>
Date:   Mon Mar 3 09:04:07 2025 -0800

    Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029)

    * Allows specifying PartitionSpec to host_to_global_device_array.

    * Generalizes to arbitrary uniform partitioning.

    * Addresses comments and adds mixed shape test.

commit 0881412
Author: Dongseong Hwang <[email protected]>
Date:   Sat Mar 1 15:41:38 2025 -0800

    Refactor Mask in Attention (apple#1028)

    Currently, the attention code is **hardcoded** to handle either `causal_mask`
    or an arbitrary `mask_fn`.

    To support **sliding window masks**, we previously used a **hack** by injecting
    the `_sliding_window_size` attribute into functions.

    This refactor **makes the masking logic more flexible** by allowing arbitrary
    `MaskFnAttentionBias`.
    - If downstream requires a **new mask pattern**, they can simply:
      1. Implement a **subclass of `MaskFnAttentionBias`**.
      2. Set `attention.mask` accordingly.

commit f67d3f9
Author: Dongseong Hwang <[email protected]>
Date:   Fri Feb 28 08:53:00 2025 -0800

    Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026)

    Currently, Flash Attention infers decoding implicitly based on circumstantial
    evidence. This PR makes the check explicit.

commit f8d2c66
Author: qdavid1 <[email protected]>
Date:   Thu Feb 27 15:26:18 2025 -0800

    External KV input for _update_layer_kwargs (apple#1025)

commit a3bf5e2
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Feb 26 17:23:40 2025 -0800

    Minor changes to Checkpointer (apple#1024)

commit 55e1841
Author: Wentao Wu <[email protected]>
Date:   Wed Feb 26 15:45:51 2025 -0800

    Add an option to break ties for top_k_logits when k = 1 (apple#1022)

    * Add an option to support stable top_k = 1.

    * address comments

    * address comments

    * address comments

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * Update axlearn/common/logit_modifiers.py

    Co-authored-by: Mark Lee <[email protected]>

    * address comments

    ---------

    Co-authored-by: Mark Lee <[email protected]>

commit fbca3fc
Author: Meng (Ethan) Li <[email protected]>
Date:   Wed Feb 26 14:05:25 2025 -0800

    Add priority_class as a launch flag (apple#1020)

commit b26bd74
Author: Meng (Ethan) Li <[email protected]>
Date:   Wed Feb 26 14:04:47 2025 -0800

    Fix TypeError in calcualte_goodput.py (apple#1023)

commit f8191e1
Author: Dongseong Hwang <[email protected]>
Date:   Wed Feb 26 11:03:44 2025 -0800

    Emulate flash attentnion unittests on CPU. (apple#1021)

    utils.py codebase is not well covered by CI because it branches different
    backend.

    This PR introduces new CPU test, utils_test.py.
    This test is expected to run on CPU and is designed to validate GPU/TPU code
    from a CPU environment by fake mesh.
    It allows quick verification in CI and local environments to ensure that code
    changes do not break GPU/TPU Flash Attention.

commit daec8c5
Author: Chang Liu <[email protected]>
Date:   Tue Feb 25 12:38:43 2025 -0800

    Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019)

    Co-authored-by: Chang Liu <[email protected]>

commit ac642ea
Author: Dongseong Hwang <[email protected]>
Date:   Tue Feb 25 12:02:57 2025 -0800

    Fix crash in log-mel frontend when waveform samples are integers. (apple#1017)

    After updating JAX, this existing hidden bug started causing CI failures.
    When the sample dtype is int32 (which is valid), `jnp.finfo` returns None,
    even though `jnp.iinfo` is available.
    The previous JAX version seemed to handle this case more forgivingly.

    ```
    ../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram
        return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny))
    ```

commit 7c64b55
Author: Meng (Ethan) Li <[email protected]>
Date:   Tue Feb 25 10:52:50 2025 -0800

    Add LoadBalancer to GKE replicatedJob (apple#1015)

    Co-authored-by: Liang (SPG) He <[email protected]>

commit 8e8a41b
Author: Chang Lan <[email protected]>
Date:   Tue Feb 25 10:26:59 2025 -0800

    Expose jax.lax.scan's unroll option to Repeat layer (apple#1016)

    * Expose jax.lax.scan's unroll option to Repeat layer.

    * Defaults to None to avoid golden config changes

commit 682bce6
Author: Dongseong Hwang <[email protected]>
Date:   Tue Feb 25 10:09:41 2025 -0800

    Handle None bias in BiasAndResidual (apple#1018)

commit f053318
Author: Ruoming Pang <[email protected]>
Date:   Mon Feb 24 11:12:23 2025 -0500

    Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013)

    * Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate().

    * Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate().

commit a93cd1b
Author: Luzy <[email protected]>
Date:   Sat Feb 22 19:55:01 2025 -0500

    fix dtype in frontend pre emphasis (apple#1014)

commit c1fe2e9
Author: Maggie Zhang <[email protected]>
Date:   Fri Feb 21 19:07:39 2025 -0800

    GoodPut minor fix: only process 0 should start goodput uploader (apple#984)

    * only process 0 will start goodput uploader

    * Add unit test

commit 4b1fbf0
Author: Chang Lan <[email protected]>
Date:   Fri Feb 21 13:36:45 2025 -0800

    Async context invocation for checkpointing (apple#1012)

    * Async context invocation supoprt for checkpointer

    * Add comment

    * Add comments

commit d4cd158
Author: Ruoming Pang <[email protected]>
Date:   Fri Feb 21 14:22:24 2025 -0500

    Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011)

    * Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase.

    This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor.

    * Fixes pytype.

    * Addresses review.

commit 2ae6e66
Author: Meng (Ethan) Li <[email protected]>
Date:   Fri Feb 21 09:22:35 2025 -0800

    Enable megascale abort on hang or error (apple#1010)

    * Enable megascale_error_reporter_abort on hang and error by default

    * Increase threshold to 10m

commit ce4b2fb
Author: Chunyang Wen <[email protected]>
Date:   Fri Feb 21 23:12:38 2025 +0800

    Add GPU monitor (apple#1006)

commit baf8ad7
Author: Dongseong Hwang <[email protected]>
Date:   Thu Feb 20 19:35:07 2025 -0800

    Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009)

commit cf41112
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Feb 20 16:29:13 2025 -0800

    Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008)

    * Revert "gRPC Checkpointer (apple#1005)"

    This reverts commit d27c562.

    * Keep some changes

commit 454bdba
Author: Matthew Hopkins <[email protected]>
Date:   Thu Feb 20 15:10:51 2025 -0800

    upgrade jax 0.4.38 (apple#1007)

commit d27c562
Author: Hanzhi Zhou <[email protected]>
Date:   Tue Feb 18 18:54:53 2025 -0800

    gRPC Checkpointer (apple#1005)

commit fb90620
Author: Ruoming Pang <[email protected]>
Date:   Tue Feb 18 21:08:38 2025 -0500

    Makes file_system.glob support multiple patterns. (apple#1003)

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

    * Makes file_system.glob support multiple patterns.

commit 334f421
Author: Mark Lee <[email protected]>
Date:   Tue Feb 18 17:03:39 2025 -0800

    Reverts sliding window attention changes. (apple#1004)

    * Revert "Fix flash decoding in GPU. (apple#999)"

    This reverts commit fdadfd8.

    * Revert "Supports TPU context parallel training (apple#981)"

    This reverts commit e151d69.

    * Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)"

    This reverts commit 67645d0.

    * Retain model/decoder asr changes.

commit 3dacc6b
Author: Chang Lan <[email protected]>
Date:   Mon Feb 17 19:45:00 2025 -0800

    Refactor aot_compilation for reuse (apple#1000)

commit c44fe18
Author: Ruoming Pang <[email protected]>
Date:   Mon Feb 17 21:38:17 2025 -0500

    Makes checkpointer_test.py use file_system. (apple#1001)

commit fdadfd8
Author: Dongseong Hwang <[email protected]>
Date:   Mon Feb 17 18:23:21 2025 -0800

    Fix flash decoding in GPU. (apple#999)

    target_positions used to be time_step, but after PR apple#995, it now represents the
    actual target positions with shape [batch, step_len].
    apple#995

    Updating the GPU decoding code to align with this change.

    CI did not cover GPU unit tests.

    TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU

commit 9e64388
Author: Ruoming Pang <[email protected]>
Date:   Mon Feb 17 16:40:03 2025 -0500

    Makes axlearn/cloud/ use file_system. (apple#998)

    * Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency.

    * Adds testing for file_system.readfile.

    * Fixes pytype.

    * Makes axlearn/cloud use file_system instead of gfile.

commit 5fba4ce
Author: Chang Lan <[email protected]>
Date:   Mon Feb 17 09:44:10 2025 -0800

    AOT compilation support for inference (apple#997)

    * Add optional `devices` init argument to InferenceRunner for passing
      fake devices during AOT compilation.

    * Add more v5e slice types.

commit e151d69
Author: Hanzhi Zhou <[email protected]>
Date:   Sun Feb 16 13:01:15 2025 -0800

    Supports TPU context parallel training (apple#981)

    Fix

    Fix tests

commit 67645d0
Author: Dongseong Hwang <[email protected]>
Date:   Sat Feb 15 13:26:51 2025 -0800

    Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)

    * Revert "Transpose kv cache for better decode performance (apple#979)"

    This reverts commit b130416.

    * Update golden configs

    * Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding.

    Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for
    sliding window attention, the KV cache is kept for the full sequence length
    (`seq_len`) instead of the window length (`window_len`).

    For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache
    for the full 2M tokens. It then biases 1999k invalid KV tokens before
    calculating attention, resulting in a computational complexity of **O(2M²)**
    instead of the desired **O(1k²)**.

    This issue persists even when using flash attention. Flash attention uses the
    KV cache allocated in HBM as its input. While unnecessary blocks are discarded
    during computation, the KV cache still occupies HBM inefficiently for the full
    2M tokens.

    To address this, when `MultiheadAttention` detects a sliding window mask, it
    stores the key-value (KV) cache in a ring buffer inside the input linear layer.
    As a result, downstream projects using `MultiheadAttention` automatically
    benefit from efficient KV cache handling in `init_states` and `extend_step`.

    Additionally, for use cases like local-global attention in LLMs, it is
    recommended to use sliding window masks even for the global attention.
    For example, to train an LLM with a context length of 8k, you can set the
    sliding window size to 8k during training. This enables functionally
    infinite decoding during inference, though accuracy beyond the trained
    context wouldn't be good.

    Note:
    * query_positions in QKVLinear.forward() was introduced by
      apple#914. It is now returned to the caller.

    This PR largely moves code from the downstream speech/streaming/sliding_window_attention.py.

    * transpose
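
A minimal sketch of the ring-buffer idea described above, assuming illustrative shapes and helper names (not the actual axlearn implementation):

```python
import jax
import jax.numpy as jnp

def ring_buffer_update(cache, new_kv, time_step):
    """cache: [batch, window_len, num_kv_heads, head_dim];
    new_kv: [batch, 1, num_kv_heads, head_dim] for the current step."""
    window_len = cache.shape[1]
    slot = time_step % window_len  # Wrap around: HBM stays O(window_len), not O(seq_len).
    return jax.lax.dynamic_update_slice_in_dim(cache, new_kv, slot, axis=1)

cache = jnp.zeros((2, 8, 4, 16))  # window_len = 8
new_kv = jnp.ones((2, 1, 4, 16))
cache = ring_buffer_update(cache, new_kv, time_step=10)  # Writes slot 2.
```

With this layout, attention only needs to mask stale slots rather than bias out `seq_len - window_len` tokens.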

commit 272a4d2
Author: Chang Lan <[email protected]>
Date:   Fri Feb 14 10:48:21 2025 -0800

    Add v5e-8 (apple#994)

commit debb46a
Author: Mark Lee <[email protected]>
Date:   Thu Feb 13 22:27:28 2025 -0800

    Decouples jobsets from replicated jobs. (apple#991)

    * Decouples jobsets from replicated jobs.

    * Address comments.

commit 8f2b99d
Author: Maggie Zhang <[email protected]>
Date:   Thu Feb 13 20:49:48 2025 -0800

    Add Goodput documentation (apple#989)

    * Temporarily change checkpointing to every 5 steps

    * revert local changes

    * Add example command for goodput usage

commit 31e8da0
Author: Alexander Pivovarov <[email protected]>
Date:   Thu Feb 13 15:44:51 2025 -0800

    Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987)

commit 7f2dd9e
Author: Apoorv Gupta <[email protected]>
Date:   Thu Feb 13 14:36:11 2025 -0800

    Flash Attention for Neuron (apple#939)

commit 6ca4f56
Author: Philipp Dufter <[email protected]>
Date:   Thu Feb 13 23:09:28 2025 +0100

    pass on log_warning in input_tf_data.skip_on_error (apple#990)

    * make log_warnings customizable in tfds skip error

    * address comments

commit 1a8a0eb
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Feb 13 13:32:34 2025 -0800

    Integrate Orbax's emergency checkpoint. (apple#820)

    * Integrate Orbax emergency checkpoint

    * Address comments

    * comment

    * Address comments

    * Upgrade orbax

    * Improve comments

    * Improve comments

    * Update for new orbax versions

    * Better timer

    * Address comments

    * Add step test

    * Fix

    * Add comment

commit 42fd715
Author: Apoorv Gupta <[email protected]>
Date:   Thu Feb 13 09:13:30 2025 -0800

    TRN2 Meshes and Configurations (apple#916)

    * TRN2 Meshes and Configurations

    * Add get_recursive and set_recursive to ConfigBase.

    * Use loops inside get/set_recursively

    + address comments

    * Update partition spec

    * Use get_recursively inside set

    * Move trn2 configs to a helper function.

    + Fix modifier tests

    * TRN2 partitionspec supports DP over FSDP and TP

    * Use for loop in get_recursively

    * Update Golden Configs

commit d47d5ce
Author: Haoshuo Huang <[email protected]>
Date:   Tue Feb 11 18:13:13 2025 -0800

    Add support to slice dataset based on proportions. (apple#982)

commit ed8f382
Author: Mark Lee <[email protected]>
Date:   Tue Feb 11 13:22:44 2025 -0800

    Allow metrics layers to have state. (apple#978)

    * Allow metrics layers to have state.

    * Move BaseLossMetrics to a new file.

commit b130416
Author: Chang Lan <[email protected]>
Date:   Tue Feb 11 00:01:28 2025 -0800

    Transpose kv cache for better decode performance (apple#979)

commit 48bf488
Author: Haoshuo Huang <[email protected]>
Date:   Mon Feb 10 22:25:18 2025 -0800

    Add support for grain.IterDataset in sampling (apple#980)

commit d4b563c
Author: Alexander Pivovarov <[email protected]>
Date:   Mon Feb 10 15:36:22 2025 -0800

    Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973)

commit 0666d80
Author: Alexander Pivovarov <[email protected]>
Date:   Mon Feb 10 15:35:23 2025 -0800

    Fix membership checks in tool_use_execution.py (apple#974)

commit 2f4763c
Author: Alexander Pivovarov <[email protected]>
Date:   Mon Feb 10 15:31:59 2025 -0800

    Remove redundant import logging (apple#975)

commit 58dcf33
Author: Hanzhi Zhou <[email protected]>
Date:   Mon Feb 10 13:41:33 2025 -0800

    Enable cudnn dropout (apple#913)

commit ae855ed
Author: Mark Lee <[email protected]>
Date:   Mon Feb 10 12:43:50 2025 -0800

    Ensures that cache_dtype is respected. (apple#977)

commit cfef38b
Author: Daniel Swann <[email protected]>
Date:   Mon Feb 10 10:56:10 2025 -0800

    :sparkles: Add cache for CloudBuild API location queries (apple#967)

commit 8fd9137
Author: Wei Liu <[email protected]>
Date:   Sun Feb 9 15:33:53 2025 -0800

    Add segment_ids option in DiTAttentionLayer (apple#976)

commit e55a404
Author: Chang Lan <[email protected]>
Date:   Sun Feb 9 04:38:49 2025 -0800

    Use broadcasting trick for KV update (apple#972)

    * Use vmap and dynamic_update_slice for KV update

    * Broadcasting trick

    * Simplify the impl per @markblee's suggestion

    * comments
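
A hedged sketch of one common form of this trick: writing each example's new KV at its own position with a one-hot broadcast, instead of vmapping `dynamic_update_slice` (shapes and names are illustrative, not the exact axlearn code):

```python
import jax
import jax.numpy as jnp

def update_kv(cache, new_kv, time_step):
    """cache: [batch, seq_len, heads, dim]; new_kv: [batch, 1, heads, dim];
    time_step: [batch] int32 write position per example."""
    seq_len = cache.shape[1]
    # [batch, seq_len, 1, 1] selector of each example's write position.
    oh = jax.nn.one_hot(time_step, seq_len, dtype=cache.dtype)[:, :, None, None]
    return cache * (1 - oh) + new_kv * oh  # One broadcasted select, no vmap.
```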

commit b955187
Author: Dongseong Hwang <[email protected]>
Date:   Fri Feb 7 14:12:48 2025 -0800

    Don't keep initial key/value inputs in the KV cache. (apple#968)

    The current code is weird. It stores the input key/value in the KV cache, but
    this doesn’t make sense in either init_states or prefill:
    * init_states: This is not prefill, so key/value should not be stored in the KV cache.
    * prefill: The extend_step() function overrides this part anyway.

    Thus, this PR removes this unnecessary and confusing logic.
    The logic was introduced in apple#860

commit c3d656d
Author: zhengdong-zhang <[email protected]>
Date:   Fri Feb 7 10:18:42 2025 -0800

    Refactorization. (apple#963)

commit 1c883d8
Author: Zhao Xu <[email protected]>
Date:   Fri Feb 7 10:02:56 2025 -0800

    Support system role when calling the Gemini API. (apple#971)

commit ceab4f4
Author: Haoshuo Huang <[email protected]>
Date:   Thu Feb 6 20:41:07 2025 -0800

    Making shared_memory configurable (apple#969)

    * Making shared_memory configurable

    * fix eol space

commit 323faa3
Author: Meng (Ethan) Li <[email protected]>
Date:   Thu Feb 6 12:11:28 2025 -0800

    Use env id for gcp settings (apple#957)

    * Use env_id to replace zone as the gcp_settings key, to support multiple envs under the same zone

    * fall back to zone

    * address comments

    * Support project in the label filter; always get zone from the gcp_settings value instead of returning it directly

commit 2ec3a02
Author: Chang Lan <[email protected]>
Date:   Wed Feb 5 22:25:58 2025 -0800

    Fix incorrect number of formatting arguments (apple#966)

commit d131d3b
Author: Nan Du <[email protected]>
Date:   Mon Feb 3 11:44:00 2025 -0800

    Reduce the verbosity of variable norm summaries (apple#965)

commit c1c6e29
Author: Kelvin Zou <[email protected]>
Date:   Fri Jan 31 22:24:39 2025 -0800

    Sliding window support for GPU flash attention (apple#962)

    * snapshot

    * snapshot

    * snapshot

    * remove unexpected change

    * adding shape comment

    * fix pylint

    * snapshot

commit 0936a17
Author: Mark Lee <[email protected]>
Date:   Fri Jan 31 13:59:12 2025 -0800

    Supports loss_weights and live_targets in metrics. (apple#960)

    * Supports loss_weights, live_targets, and module sharing in metrics.

    * Addresses comments.

    * Explicitly test flatten_metrics=True.

commit 7a40f91
Author: Dipannita Shaw <[email protected]>
Date:   Fri Jan 31 11:45:33 2025 -0800

    Add Goodput & Badput recording and monitoring support. (apple#783)

    * Code clean up

    * Add more testing

    * Fix docstrings

    * Remove recorder calls from trainer for now

    * Code cleanup gcp/measurement.py

    Co-authored-by: Ruoming Pang <[email protected]>

    * Code cleanup  common/measurement.py

    Co-authored-by: Ruoming Pang <[email protected]>

    * Fix pre commit errors

    * Adding more tests

    * Further clean up

    * Fix a test error

    ---------

    Co-authored-by: Ruoming Pang <[email protected]>

commit 031a7f3
Author: Mark Lee <[email protected]>
Date:   Thu Jan 30 20:19:12 2025 -0800

    Skipping empty grain batches during unbatch. (apple#961)

    * Skipping empty grain batches during unbatch.

    * Use a loop instead of recursion.

commit 795da33
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Jan 30 07:17:16 2025 -0800

    Optimizer offloading through weight-only offload (apple#867)

    * Optimizer offloading

    * Style fix

    * Type fix

commit b1a1a5a
Author: Haoshuo Huang <[email protected]>
Date:   Wed Jan 29 21:44:15 2025 -0800

    Improve gcsfuse io (apple#959)

commit d76ef6f
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Jan 29 15:10:13 2025 -0800

    SplashAttention performance tuning for v6e (apple#958)

    * SplashAttention tuning for v6e

    * Add import to fix pytype errors

commit 2d002e3
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Jan 29 12:07:56 2025 -0800

    Use InputDispatcher for fuji models (apple#956)

    * Use dispatcher

    * Update golden configs

    * Remove logical feed indices

commit fad264b
Author: Mark Lee <[email protected]>
Date:   Tue Jan 28 10:41:54 2025 -0800

    Explicitly pass module outputs to metrics. (apple#953)

    * Explicitly pass module outputs to metrics.

    * Support and add checks for module/state updates.

    * Only flatten summaries.

commit 59508e3
Author: Hanzhi Zhou <[email protected]>
Date:   Tue Jan 28 10:34:52 2025 -0800

    Add v6e PCIe overload workaround flag (apple#955)

commit 028ecfd
Author: Haoshuo Huang <[email protected]>
Date:   Mon Jan 27 20:54:28 2025 -0800

    Fix GCSFUSE flags by setting resource limit. (apple#954)

commit 3e2c6dd
Author: Matthew Hopkins <[email protected]>
Date:   Mon Jan 27 14:56:42 2025 -0800

    update jax to 0.4.37 (apple#948)

    update BlockSpec usage in tpu_attention
    use TYPE_CHECKING for BuildDatasetFn in input_fake
    add todo for BuildDatasetFn

commit b125f00
Author: Hanzhi Zhou <[email protected]>
Date:   Mon Jan 27 11:29:23 2025 -0800

    Add v6e special meshes (apple#952)

    * Add v6e special mesh

    * Add v6e special mesh

    * Fix

    * Fix

commit a854738
Author: Firenze11 <[email protected]>
Date:   Mon Jan 27 09:17:46 2025 -0800

    Allow external positions to be input in RoPE embedding layer (apple#926)

    * Allow external positions to be input in RoPE embedding layer

    Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the current implementation of `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.

    * Update attention_test.py

    * Update dit.py

    * Update attention.py

    * Update attention_test.py

    * Update attention.py

    * Update dit.py

    * Update axlearn/common/attention.py

    Co-authored-by: Mark Lee <[email protected]>

    * respond to comments.

    Co-authored-by: Ruoming Pang <[email protected]>

    * Update attention.py

    * Update attention.py

    * Update attention.py

    ---------

    Co-authored-by: Mark Lee <[email protected]>
    Co-authored-by: Ruoming Pang <[email protected]>
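
A self-contained sketch of rotary embeddings applied at caller-supplied positions, which is the capability this PR adds; the helper below is hypothetical and not the axlearn API:

```python
import jax.numpy as jnp

def rope(x, positions, base=10000.0):
    """x: [batch, seq_len, dim] with even dim; positions: [batch, seq_len]."""
    dim = x.shape[-1]
    inv_freq = 1.0 / base ** (jnp.arange(0, dim, 2) / dim)  # [dim // 2]
    angles = positions[..., None] * inv_freq                # [batch, seq, dim // 2]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., ::2], x[..., 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(x.shape)

# Callers may assign custom indices to masked positions, e.g. zeros:
x = jnp.ones((2, 6, 8))
positions = jnp.array([[0, 1, 2, 0, 0, 3]] * 2)  # 0 marks masked tokens here.
y = rope(x, positions)
```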

commit 999401a
Author: qdavid1 <[email protected]>
Date:   Mon Jan 27 09:11:17 2025 -0800

    Update LoraFusedQKVLinear (apple#949)

commit 1c22688
Author: Mark Lee <[email protected]>
Date:   Sun Jan 26 04:51:02 2025 -0800

    Workaround module outputs being dropped. (apple#951)

commit 94c81cb
Author: Meng (Ethan) Li <[email protected]>
Date:   Fri Jan 24 11:01:45 2025 -0800

    Add link to github issue regarding kubernetes-32.0.0 (apple#947)

commit a6e0f4a
Author: Meng (Ethan) Li <[email protected]>
Date:   Fri Jan 24 08:40:25 2025 -0800

    Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946)

commit 076521a
Author: Mark Lee <[email protected]>
Date:   Thu Jan 23 15:11:00 2025 -0800

    Forward input keys to decoder. (apple#944)

commit 30284c8
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Jan 23 10:33:54 2025 -0800

    Legacy flash remat fix (apple#943)

    * Fix the same problem for legacy tpu attn

    * Fix

commit 6a9f980
Author: Mark Lee <[email protected]>
Date:   Thu Jan 23 09:20:46 2025 -0800

    Adds mesh rule for a3-megagpu-8g. (apple#936)

commit ac7a3ed
Author: Dongseong Hwang <[email protected]>
Date:   Thu Jan 23 08:15:27 2025 -0800

    Enabled running Pallas Flash Attention on CPU. (apple#922)

    Pallas supports CPU simulation (`interpret=True`), so we can use the same
    TPU Pallas kernel on CPU — making code debugging easier.

    This change lets the following unittests run on CPU as if they were on TPU,
    enabling easier testing and debugging:
    - `axlearn/common/flash_attention/tpu_attention_test.py`

    Similarly, `gpu_attention_test.py` can also be run on CPU as if it were on GPU.
    - `axlearn/common/flash_attention/gpu_attention_test.py`

    Now CI covers those tests on CPU as well.
    On an M3 Max MacBook Pro, test coverage and processing times are as follows:
    * axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
    * axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s
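
A minimal sketch of `interpret=True`, using a toy add kernel rather than one of the repo's attention kernels:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # Simulate the kernel on CPU for debugging.
)(x, y)
```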

commit 8ea85bd
Author: Hanzhi Zhou <[email protected]>
Date:   Wed Jan 22 09:51:15 2025 -0800

    Some fixes for flash remat (apple#942)

commit 185b1b5
Author: Chang Lan <[email protected]>
Date:   Tue Jan 21 11:21:08 2025 -0800

    Repeat KV heads in Flash Attention (apple#938)

    * Roll back '_repeat_kv_heads' change in Flash Attention

    A recent PR removed _repeat_kv_heads from Flash Attention as a GQA optimization,
    in the hope of reducing HBM usage. However, the actual HBM saving would be limited
    in the model-parallel setting, as the heads are already sharded across devices.
    It also introduces limitations that break some of the existing sharding
    configurations.

    For example, say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads,
    we can set the model axis to 8 so that each device holds only one Q, K, and V head;
    without repeat_kv_heads, the max model axis is 4, so each device holds 2 Q heads,
    increasing the actual HBM usage.

    * Repeat kv as necessary for sharding

    * Unit tests

    * Address comments.
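
A hedged sketch of the repeat-KV-heads step for GQA, matching the 8-head / 4-KV-head example above (illustrative shapes, not the exact axlearn code):

```python
import jax.numpy as jnp

def repeat_kv_heads(kv, num_heads):
    """kv: [batch, seq_len, num_kv_heads, head_dim] -> [batch, seq_len, num_heads, head_dim]."""
    num_kv_heads = kv.shape[2]
    if num_heads == num_kv_heads:
        return kv
    # With num_heads=8 and num_kv_heads=4, each KV head is repeated twice,
    # so a model axis of 8 still leaves one Q/K/V head per device.
    return jnp.repeat(kv, num_heads // num_kv_heads, axis=2)

kv = jnp.zeros((2, 16, 4, 64))
kv_full = repeat_kv_heads(kv, num_heads=8)  # [2, 16, 8, 64]
```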

commit 4678740
Author: Chang Lan <[email protected]>
Date:   Mon Jan 20 20:36:44 2025 -0800

    AOT compilation for v6e (apple#937)

commit 357bef6
Author: Mark Lee <[email protected]>
Date:   Mon Jan 20 20:23:39 2025 -0800

    Makes causal lm metrics configurable. (apple#934)

    * Makes causal lm metrics configurable.

    * Address review comments.

    * Make metrics required.

    * Update golden configs.

    * Removes PredictModel.

commit 16ca0c2
Author: Mark Lee <[email protected]>
Date:   Sun Jan 19 14:19:20 2025 -0800

    Supports flexible input partition specs. (apple#933)

    * Supports flexible input partition specs in causal lm.

    * Moves the input partitioning to Input.

    * Adds missing pytest marker.

    * Address review comments.

    * Rebase and update golden configs.

    * Fixes batch axis names and adds a test.

commit 9b75ef1
Author: Mark Lee <[email protected]>
Date:   Sun Jan 19 07:43:19 2025 -0800

    Avoid a top-level import of tokenizers. (apple#935)

commit 9996f34
Author: sychen52 <[email protected]>
Date:   Sat Jan 18 09:44:04 2025 -0800

    Add llama 3 tokenizer (apple#850)

    * Add llama 3 tokenizer

    add a new version called V3_TIKTOKEN.

    other edits based on suggestions.

    * Handle special tokens like other vocabularies.

    * use encode instead of encode_batch

commit ad14de3
Author: Haoshuo Huang <[email protected]>
Date:   Fri Jan 17 14:19:24 2025 -0800

    Add ReadOptions args to _make_autoregressive_inputs (apple#931)

    * Add ReadOptions args to _make_autoregressive_inputs

    * use read_options as args instead

commit 4858070
Author: Sam Stoelinga <[email protected]>
Date:   Fri Jan 17 13:54:05 2025 -0800

    improve GCS perf: Change resource limit to request (apple#851)

commit b0ee05e
Author: Bailin <[email protected]>
Date:   Fri Jan 17 22:53:00 2025 +0800

    Add Mamba2 and its Jamba variant (apple#839)

    * add mamba2

    * merge

    * unify init and prefill

    * adapt final changes

    ---------

    Co-authored-by: bailin_wang <[email protected]>

commit 1e25e4a
Author: Hanzhi Zhou <[email protected]>
Date:   Thu Jan 16 11:25:24 2025 -0800

    Cache AoT compilation result (apple#927)

    * Cache AoT compilation result

    * Fix comments

    * Fix

    * Fix

    * Fix

    * Fix