# Source view: GitHub Actions run "Build Linux JAX Wheels (gfx94X-dcgpu, 3.12, dev)" #1997

# Workflow header: display name, triggers (reusable call + manual dispatch),
# token permissions, and the per-run title.
name: 'Build Portable Linux JAX Wheels'

on:
  # Entry point used when another workflow invokes this one via `uses:`.
  workflow_call:
    inputs:
      # --- Build target selection ---
      amdgpu_family:
        type: string
        required: true
      python_version:
        type: string
        required: true
      release_type:
        description: 'The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!'
        type: string
        required: true

      # --- S3 locations (the GPU family is appended by the jobs) ---
      s3_subdir:
        description: 'S3 subdirectory, not including the GPU-family'
        type: string
        required: true
      s3_staging_subdir:
        description: 'S3 staging subdirectory, not including the GPU-family'
        type: string
        required: true

      # --- ROCm inputs for the build ---
      rocm_version:
        description: 'ROCm version to install'
        type: string
      tar_url:
        description: 'URL to TheRock tarball to build against'
        type: string

      # --- Python package index URLs ---
      cloudfront_url:
        description: 'CloudFront URL pointing to Python index'
        type: string
        required: true
      cloudfront_staging_url:
        description: 'CloudFront base URL pointing to staging Python index'
        type: string
        required: true

      # --- Source checkout selection ---
      repository:
        description: 'Repository to checkout. Defaults to `ROCm/TheRock`.'
        type: string
        default: 'ROCm/TheRock'
      ref:
        description: 'Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow.'
        type: string
      # NOTE(review): no job in this file reads `inputs.jax_ref`; the build job
      # takes its ref from a fixed matrix value. This default also differs from
      # the workflow_dispatch default below - confirm which one is intended.
      jax_ref:
        description: 'rocm-jax repository ref/branch to check out'
        type: string
        default: 'rocm-jaxlib-v0.8.0-fixdevtar'

  # Entry point for manual runs from the Actions UI.
  workflow_dispatch:
    inputs:
      amdgpu_family:
        type: choice
        default: gfx94X-dcgpu
        options:
          - gfx101X-dgpu
          - gfx103X-dgpu
          - gfx110X-all
          - gfx1150
          - gfx1151
          - gfx1152
          - gfx1153
          - gfx120X-all
          - gfx90X-dcgpu
          - gfx94X-dcgpu
          - gfx950-dcgpu
      python_version:
        type: string
        required: true
        # Quoted so the version is not parsed as a float.
        default: '3.12'
      release_type:
        description: 'Type of release to create. All developer-triggered jobs should use "dev"!'
        type: choice
        default: dev
        options:
          - dev
          - nightly
          - prerelease
      s3_subdir:
        description: 'S3 subdirectory, not including the GPU-family'
        type: string
        default: 'v2'
      s3_staging_subdir:
        description: 'S3 staging subdirectory, not including the GPU-family'
        type: string
        default: 'v2-staging'
      rocm_version:
        description: 'ROCm version to install'
        type: string
      tar_url:
        description: 'URL to TheRock tarball to build against'
        type: string
      cloudfront_url:
        description: 'CloudFront base URL pointing to Python index'
        type: string
        default: 'https://rocm.devreleases.amd.com/v2'
      cloudfront_staging_url:
        description: 'CloudFront base URL pointing to staging Python index'
        type: string
        default: 'https://rocm.devreleases.amd.com/v2-staging'
      ref:
        description: 'TheRock repository ref/branch to check out'
        type: string
        default: ''
      # NOTE(review): declared but not read by any job in this file (see the
      # matching note on the workflow_call input).
      jax_ref:
        description: 'rocm-jax repository ref/branch to check out'
        type: string
        default: 'rocm-jaxlib-v0.8.2'

permissions:
  # Allows jobs to request an OIDC token; the AWS credential steps in this
  # file assume a role via `role-to-assume`.
  id-token: write
  contents: read

run-name: "Build Linux JAX Wheels (${{ inputs.amdgpu_family }}, ${{ inputs.python_version }}, ${{ inputs.release_type }})"
jobs:
build_jax_wheels:
  # Builds the JAX wheels against a TheRock tarball, records the wheel versions
  # as job outputs, and (in the ROCm org only) uploads the wheels to the staging
  # S3 prefix and regenerates the staging package index.
  strategy:
    matrix:
      # NOTE(review): the `jax_ref` workflow input is not consulted here; the
      # rocm-jax ref is fixed by this matrix. Confirm that is intended.
      jax_ref: [rocm-jaxlib-v0.8.2]
  name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }}
  # Use the ROCm org's scale-set runner when available, a hosted runner otherwise.
  runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
  env:
    PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
    S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
  outputs:
    # `cp_version` is read from the job environment; presumably exported by the
    # "Select Python version" step (script not shown here).
    cp_version: ${{ env.cp_version }}
    jax_version: ${{ steps.write_jax_versions.outputs.jax_version }}
    jaxlib_version: ${{ steps.write_jax_versions.outputs.jaxlib_version }}
    jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }}
    jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }}
    jax_ref: ${{ matrix.jax_ref }}
  steps:
    - name: Checkout TheRock
      uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
      with:
        # Honour the repository/ref inputs exactly like the other jobs in this
        # workflow, so every job runs build_tools/ scripts from the same ref.
        repository: ${{ inputs.repository || github.repository }}
        ref: ${{ inputs.ref || '' }}

    - name: Checkout JAX
      uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
      with:
        path: jax
        repository: rocm/rocm-jax
        ref: ${{ matrix.jax_ref }}

    - name: Configure Git Identity
      # NOTE(review): the e-mail value below looks like scrape-obfuscated
      # placeholder text - restore the real address before using this file.
      run: |
        git config --global user.name "therockbot"
        git config --global user.email "[email protected]"

    - name: "Setting up Python"
      uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
      with:
        python-version: ${{ inputs.python_version }}

    - name: Select Python version
      env:
        # Inputs reach the shell through env vars instead of being expanded
        # into the script text (avoids script injection via input values).
        JAX_BUILD_PYTHON_VERSION: ${{ inputs.python_version }}
      run: |
        python build_tools/github_actions/python_to_cp_version.py \
          --python-version "${JAX_BUILD_PYTHON_VERSION}"

    - name: Build JAX Wheels
      env:
        ROCM_VERSION: ${{ inputs.rocm_version }}
        JAX_BUILD_PYTHON_VERSION: ${{ inputs.python_version }}
        JAX_BUILD_THEROCK_TAR_URL: ${{ inputs.tar_url }}
      run: |
        ls -lah
        pushd jax
        python3 build/ci_build \
          --compiler=clang \
          --python-versions="${JAX_BUILD_PYTHON_VERSION}" \
          --rocm-version="${ROCM_VERSION}" \
          --therock-path="${JAX_BUILD_THEROCK_TAR_URL}" \
          dist_wheels

    - name: Extract JAX versions from built wheels
      id: write_jax_versions
      run: |
        python ./build_tools/github_actions/write_jax_versions.py \
          --dist-dir "${PACKAGE_DIST_DIR}"

    # The AWS tooling is only needed by the upload/index steps below, which run
    # solely in the ROCm org after a successful build, so it is gated the same
    # way instead of running unconditionally.
    - name: Install AWS CLI
      if: ${{ github.repository_owner == 'ROCm' }}
      run: bash ./dockerfiles/install_awscli.sh

    - name: Configure AWS Credentials
      if: ${{ github.repository_owner == 'ROCm' }}
      uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
      with:
        aws-region: us-east-2
        role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}

    - name: Upload wheels to S3
      if: ${{ github.repository_owner == 'ROCm' }}
      env:
        JAX_S3_STAGING_SUBDIR: ${{ inputs.s3_staging_subdir }}
        JAX_AMDGPU_FAMILY: ${{ inputs.amdgpu_family }}
      run: |
        # Only *.whl files from the wheelhouse are copied to the staging prefix.
        aws s3 cp "${PACKAGE_DIST_DIR}/" "s3://${S3_BUCKET_PY}/${JAX_S3_STAGING_SUBDIR}/${JAX_AMDGPU_FAMILY}/" \
          --recursive --exclude "*" --include "*.whl"

    - name: (Re-)Generate Python package release index
      if: ${{ github.repository_owner == 'ROCm' }}
      env:
        # NOTE(review): the release-index step in upload_jax_wheels exports
        # CUSTOM_PREFIX for manage.py; this step never did - confirm whether
        # manage.py needs it for the staging index as well.
        JAX_STAGING_INDEX_PREFIX: ${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}
      run: |
        python3 -m venv .venv
        source .venv/bin/activate
        pip3 install boto3 packaging
        python3 ./build_tools/third_party/s3_management/manage.py "${JAX_STAGING_INDEX_PREFIX}"
generate_target_to_run:
  # Runs configure_target_run.py for the requested GPU family and republishes
  # two of the `configure` step's outputs for downstream jobs.
  name: 'Generate target_to_run'
  runs-on: 'ubuntu-24.04'
  outputs:
    # Note the differing key spelling: the step output is `test-runs-on`.
    test_runs_on: ${{ steps.configure.outputs.test-runs-on }}
    bypass_tests_for_releases: ${{ steps.configure.outputs.bypass_tests_for_releases }}
  steps:
    - name: 'Checking out repository'
      uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
      with:
        # Falls back to the current repository / default ref when the inputs
        # are empty or not provided.
        ref: ${{ inputs.ref || '' }}
        repository: ${{ inputs.repository || github.repository }}

    - id: configure
      name: 'Generating target to run'
      run: python ./build_tools/github_actions/configure_target_run.py
      env:
        PLATFORM: 'linux'
        TARGET: ${{ inputs.amdgpu_family }}
        # Quoted: step environment values are strings.
        LOAD_TEST_RUNNERS_FROM_VAR: 'false'
        # Variable comes from ROCm organization variable 'ROCM_THEROCK_TEST_RUNNERS'
        ROCM_THEROCK_TEST_RUNNERS: ${{ vars.ROCM_THEROCK_TEST_RUNNERS }}
test_jax_wheels:
  # Delegates to the reusable test workflow, pointing it at the staging package
  # index and forwarding the versions reported by the build job.
  name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }}
  needs:
    - build_jax_wheels
    - generate_target_to_run
  permissions:
    contents: read
    packages: read
  uses: ./.github/workflows/test_linux_jax_wheels.yml
  with:
    # Target and runner selection.
    amdgpu_family: ${{ inputs.amdgpu_family }}
    test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
    python_version: ${{ inputs.python_version }}
    release_type: ${{ inputs.release_type }}

    # ROCm inputs and the index to install wheels from (staging).
    rocm_version: ${{ inputs.rocm_version }}
    tar_url: ${{ inputs.tar_url }}
    package_index_url: ${{ inputs.cloudfront_staging_url }}

    # Source checkout selection; same fallback as the other jobs.
    repository: ${{ inputs.repository || github.repository }}
    ref: ${{ inputs.ref || '' }}

    # Values produced by the build job.
    jax_ref: ${{ needs.build_jax_wheels.outputs.jax_ref }}
    jax_version: ${{ needs.build_jax_wheels.outputs.jax_version }}
    jaxlib_version: ${{ needs.build_jax_wheels.outputs.jaxlib_version }}
upload_jax_wheels:
  # Promotes the staged wheels to the release S3 prefix when the promotion
  # policy script allows it, then regenerates the release package index.
  name: Release JAX Wheels to S3
  needs: [build_jax_wheels, generate_target_to_run, test_jax_wheels]
  # Runs even when an upstream job failed or was skipped (but not when the run
  # was cancelled); the policy script below receives the upstream results.
  if: ${{ !cancelled() }}
  runs-on: ubuntu-24.04
  env:
    S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
    JAXLIB_VERSION: "${{ needs.build_jax_wheels.outputs.jaxlib_version }}"
    JAX_PLUGIN_VERSION: "${{ needs.build_jax_wheels.outputs.jax_plugin_version }}"
    JAX_PJRT_VERSION: "${{ needs.build_jax_wheels.outputs.jax_pjrt_version }}"
    CP_VERSION: "${{ needs.build_jax_wheels.outputs.cp_version }}"
  steps:
    - name: Checkout
      uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
      with:
        repository: ${{ inputs.repository || github.repository }}
        ref: ${{ inputs.ref || '' }}

    - name: Configure AWS Credentials
      if: always()
      uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
      with:
        aws-region: us-east-2
        role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}

    - name: Determine upload flag
      # The later steps gate on `env.upload`, so this script is expected to
      # export `upload` into the job environment (script not shown here).
      env:
        BUILD_RESULT: ${{ needs.build_jax_wheels.result }}
        TEST_RESULT: ${{ needs.test_jax_wheels.result }}
        TEST_RUNS_ON: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
        BYPASS_TESTS_FOR_RELEASES: ${{ needs.generate_target_to_run.outputs.bypass_tests_for_releases }}
      run: python ./build_tools/github_actions/promote_wheels_based_on_policy.py

    - name: Copy JAX wheels from staging to release S3
      if: ${{ env.upload == 'true' }}
      env:
        # Inputs reach the shell through env vars instead of being expanded
        # into the script text (avoids script injection via input values).
        JAX_S3_STAGING_SUBDIR: ${{ inputs.s3_staging_subdir }}
        JAX_S3_RELEASE_SUBDIR: ${{ inputs.s3_subdir }}
        JAX_AMDGPU_FAMILY: ${{ inputs.amdgpu_family }}
      run: |
        echo "Copying exact tested wheels to release S3 bucket..."
        echo "  JAXLIB_VERSION=${JAXLIB_VERSION}"
        echo "  JAX_PLUGIN_VERSION=${JAX_PLUGIN_VERSION}"
        echo "  JAX_PJRT_VERSION=${JAX_PJRT_VERSION}"
        echo "  CP_VERSION=${CP_VERSION}"
        # Exclude everything, then re-include only the three wheel files that
        # match the versions reported by the build job.
        aws s3 cp \
          "s3://${S3_BUCKET_PY}/${JAX_S3_STAGING_SUBDIR}/${JAX_AMDGPU_FAMILY}/" \
          "s3://${S3_BUCKET_PY}/${JAX_S3_RELEASE_SUBDIR}/${JAX_AMDGPU_FAMILY}/" \
          --recursive \
          --exclude "*" \
          --include "jaxlib-${JAXLIB_VERSION}-${CP_VERSION}-manylinux_2_27_x86_64.whl" \
          --include "jax_rocm7_plugin-${JAX_PLUGIN_VERSION}-${CP_VERSION}-manylinux_2_28_x86_64.whl" \
          --include "jax_rocm7_pjrt-${JAX_PJRT_VERSION}-py3-none-manylinux_2_28_x86_64.whl"

    - name: (Re-)Generate Python package release index
      if: ${{ env.upload == 'true' }}
      env:
        # Environment variables to be set for `manage.py`
        CUSTOM_PREFIX: "${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}"
      run: |
        pip install boto3 packaging
        python ./build_tools/third_party/s3_management/manage.py "${CUSTOM_PREFIX}"