
I've added unit tests for critical conversion, layer, and inference components. #1804


Draft · wants to merge 2 commits into base: main
Conversation

@stingram (Collaborator) commented Jun 4, 2025

This commit introduces unit tests for several key modules that previously lacked them:

1.  **Checkpoint Conversion (`MaxText/checkpoint_conversion_utils.py`):**
    *   I refactored the core logic from `convert_gemma_chkpt.py` into the new `checkpoint_conversion_utils.py` so it can be tested without heavy dependencies.
    *   I then added tests in `MaxText/tests/convert_gemma_chkpt_test.py` covering parameter name mapping, tensor transformations (scaling, transposition), and model-specific (MQA vs. MHA) logic for Gemma checkpoint conversion (the first sketch after the summary below illustrates the pattern).

2.  **Gemma Layer (`MaxText/layers/gemma.py`):**
    *   Next, I added tests in `MaxText/tests/layers/gemma_layer_test.py` for the `GemmaDecoderLayer`.
    *   These tests verify correct output shapes, dropout behavior (deterministic vs. non-deterministic), and handling of different model modes (TRAIN, PREFILL, AUTOREGRESSIVE), including proper KV cache initialization (see the second sketch below).
    *   Setting up a valid `Config` object for layer tests was nontrivial; minimal `pyconfig.initialize` arguments plus manual post-initialization configuration solved it.

3.  **Paged Attention (`MaxText/inference/paged_attention.py`):**
    *   I also added tests in `MaxText/tests/inference/paged_attention_test.py` for the `PagedAttentionOp`.
    *   These tests verify output shapes for both prefill and decode modes.
    *   A basic attention value calculation is tested for the prefill path.
    *   The decode-path shape test is marked as an `@unittest.expectedFailure` on CPU due to a Pallas kernel incompatibility, documenting this current limitation (the third sketch below shows the pattern).
    *   The test setup for `PagedAttentionOp` also required careful `Config` and mock `PageState` creation.

These tests improve the robustness and maintainability of the MaxText codebase by providing automated verification for these critical components.
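
To show the shape of these checks, here is a minimal, self-contained sketch of the conversion-utility tests. The transforms are inlined stand-ins (`np.transpose` plus a constant rescale) rather than the real `checkpoint_conversion_utils` API, and the scale constant is hypothetical; the actual assertions live in `MaxText/tests/convert_gemma_chkpt_test.py`:

```python
import unittest

import numpy as np


class ConversionTransformTestSketch(unittest.TestCase):
  """Sketch of the shape/value checks used for checkpoint conversion."""

  def test_transpose_preserves_values(self):
    # A conversion step that reorders kernel axes (e.g. attention
    # weights) must be a pure permutation: same values, permuted shape.
    kernel = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
    converted = np.transpose(kernel, (1, 0, 2))  # stand-in transform
    self.assertEqual(converted.shape, (3, 2, 4))
    # Applying the inverse permutation must recover the original tensor.
    np.testing.assert_array_equal(np.transpose(converted, (1, 0, 2)), kernel)

  def test_scaling_rule(self):
    # Stand-in for a step that rescales a tensor by a model-dependent
    # constant before the weights are written to the MaxText checkpoint.
    scale = 0.5  # hypothetical constant, not Gemma's actual scale factor
    tensor = np.ones((4, 4), dtype=np.float32)
    np.testing.assert_allclose(tensor * scale, np.full((4, 4), 0.5))


if __name__ == "__main__":
  unittest.main()
```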
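The layer tests follow the standard Flax unit-test pattern. In this sketch a toy residual module stands in for `GemmaDecoderLayer` (the real test builds a full `Config` via `pyconfig.initialize` instead of hard-coding dimensions); it demonstrates the output-shape check and the deterministic vs. non-deterministic dropout check:

```python
import unittest

import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyDecoderLayer(nn.Module):
  """Toy stand-in for GemmaDecoderLayer: a residual MLP with dropout."""

  emb_dim: int
  dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, x, deterministic: bool):
    h = nn.Dense(self.emb_dim)(x)
    h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=deterministic)
    return x + h  # residual connection: output shape matches input shape


class GemmaLayerTestSketch(unittest.TestCase):

  def setUp(self):
    self.layer = ToyDecoderLayer(emb_dim=16)
    self.x = jnp.ones((2, 8, 16))  # (batch, sequence, embedding)
    self.params = self.layer.init(
        jax.random.PRNGKey(0), self.x, deterministic=True)

  def test_output_shape(self):
    y = self.layer.apply(self.params, self.x, deterministic=True)
    self.assertEqual(y.shape, self.x.shape)

  def test_dropout_determinism(self):
    # deterministic=True must not consume the dropout RNG, so repeated
    # calls agree; deterministic=False with different keys should differ.
    y1 = self.layer.apply(self.params, self.x, deterministic=True)
    y2 = self.layer.apply(self.params, self.x, deterministic=True)
    self.assertTrue(jnp.allclose(y1, y2))
    d1 = self.layer.apply(
        self.params, self.x, deterministic=False,
        rngs={"dropout": jax.random.PRNGKey(1)})
    d2 = self.layer.apply(
        self.params, self.x, deterministic=False,
        rngs={"dropout": jax.random.PRNGKey(2)})
    self.assertFalse(jnp.allclose(d1, d2))


if __name__ == "__main__":
  unittest.main()
```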
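And here is the conditional `@unittest.expectedFailure` pattern for the decode path. `run_decode_kernel` is a placeholder that mimics the real constraint (the Pallas kernel cannot lower to CPU) rather than the actual `PagedAttentionOp` API; on CPU the test is recorded as an expected failure, while on TPU/GPU it runs normally:

```python
import unittest

import jax
import jax.numpy as jnp


def expected_failure_on_cpu(test_fn):
  """Marks a test as an expected failure when JAX's backend is CPU."""
  if jax.default_backend() == "cpu":
    return unittest.expectedFailure(test_fn)
  return test_fn


def run_decode_kernel(q, kv):
  """Placeholder for the Pallas-backed decode path of PagedAttentionOp.

  The real kernel has no CPU lowering; this stand-in mimics that by
  refusing to run there, so the expected-failure path is exercised.
  """
  if jax.default_backend() == "cpu":
    raise NotImplementedError("Pallas kernel has no CPU lowering")
  return jnp.einsum("bqd,bkd->bqk", q, kv)  # toy attention scores


class PagedAttentionDecodeShapeTestSketch(unittest.TestCase):

  @expected_failure_on_cpu
  def test_decode_output_shape(self):
    batch, q_len, kv_len, head_dim = 1, 1, 32, 8
    q = jnp.ones((batch, q_len, head_dim))
    kv = jnp.ones((batch, kv_len, head_dim))
    scores = run_decode_kernel(q, kv)
    self.assertEqual(scores.shape, (batch, q_len, kv_len))


if __name__ == "__main__":
  unittest.main()
```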

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [ ] I have performed a self-review of my code.
- [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
- [ ] I have run end-to-end tests and provided workload links above if applicable.
- [ ] I have made or will make corresponding changes to the doc if needed.

The second commit ran pyink (from code_style.sh) to automatically format the code and resolve the linting errors that came up. The formatting changes keep the style consistent across the new unit tests and related utility files:
- MaxText/checkpoint_conversion_utils.py
- MaxText/tests/convert_gemma_chkpt_test.py
- MaxText/tests/layers/gemma_layer_test.py
- MaxText/tests/inference/paged_attention_test.py

I also reformatted several existing files with pyink; the underlying logic of the tests and code is unchanged.