Build Linux JAX Wheels (gfx94X-dcgpu, 3.12, dev) #1997
Workflow file for this run
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: Build Portable Linux JAX Wheels

# GitHub's loader treats `on` as a key; generic YAML 1.1 parsers read it as a boolean.
on:  # yamllint disable-line rule:truthy
  # Reusable entry point for other workflows.
  workflow_call:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build for (e.g. gfx94X-dcgpu)
        required: true
        type: string
      python_version:
        description: Python version to build wheels for (e.g. "3.12")
        required: true
        type: string
      release_type:
        description: The type of release to build ("dev", "nightly", or "prerelease"). All developer-triggered jobs should use "dev"!
        required: true
        type: string
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        required: true
        type: string
      s3_staging_subdir:
        description: S3 staging subdirectory, not including the GPU-family
        required: true
        type: string
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
      cloudfront_url:
        description: CloudFront base URL pointing to Python index
        required: true
        type: string
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        required: true
        type: string
      repository:
        description: "Repository to checkout. Defaults to `ROCm/TheRock`."
        type: string
        default: "ROCm/TheRock"
      ref:
        description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
        type: string
      jax_ref:
        description: "rocm-jax repository ref/branch to check out"
        type: string
        # Kept in sync with the workflow_dispatch default below.
        default: "rocm-jaxlib-v0.8.2"

  # Manual trigger from the Actions UI.
  workflow_dispatch:
    inputs:
      amdgpu_family:
        description: AMD GPU family to build for
        type: choice
        options:
          - gfx101X-dgpu
          - gfx103X-dgpu
          - gfx110X-all
          - gfx1150
          - gfx1151
          - gfx1152
          - gfx1153
          - gfx120X-all
          - gfx90X-dcgpu
          - gfx94X-dcgpu
          - gfx950-dcgpu
        default: gfx94X-dcgpu
      python_version:
        description: Python version to build wheels for
        required: true
        type: string
        # Quoted so it stays a string ("3.10" would otherwise become 3.1).
        default: "3.12"
      release_type:
        type: choice
        description: Type of release to create. All developer-triggered jobs should use "dev"!
        options:
          - dev
          - nightly
          - prerelease
        default: dev
      s3_subdir:
        description: S3 subdirectory, not including the GPU-family
        type: string
        default: "v2"
      s3_staging_subdir:
        description: S3 staging subdirectory, not including the GPU-family
        type: string
        default: "v2-staging"
      rocm_version:
        description: ROCm version to install
        type: string
      tar_url:
        description: URL to TheRock tarball to build against
        type: string
      cloudfront_url:
        description: CloudFront base URL pointing to Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2"
      cloudfront_staging_url:
        description: CloudFront base URL pointing to staging Python index
        type: string
        default: "https://rocm.devreleases.amd.com/v2-staging"
      ref:
        description: TheRock repository ref/branch to check out
        type: string
        default: ""
      jax_ref:
        description: rocm-jax repository ref/branch to check out
        type: string
        default: "rocm-jaxlib-v0.8.2"
# Title shown for each run in the Actions UI.
run-name: "Build Linux JAX Wheels (${{ inputs.amdgpu_family }}, ${{ inputs.python_version }}, ${{ inputs.release_type }})"

# Workflow-wide token scopes: repository contents are only read, and
# `id-token: write` allows aws-actions/configure-aws-credentials to request an
# OIDC token when assuming the IAM role used for the S3 steps.
permissions:
  contents: read
  id-token: write
jobs:
  # Builds jaxlib / jax_rocm plugin / pjrt wheels against TheRock and pushes them
  # to the *staging* S3 index.
  build_jax_wheels:
    strategy:
      matrix:
        # Single-entry matrix driven by the `jax_ref` input; the value is read
        # back via `matrix.jax_ref` for the JAX checkout and the job outputs.
        # Falls back to the dispatch default if an empty string is passed.
        jax_ref:
          - ${{ inputs.jax_ref || 'rocm-jaxlib-v0.8.2' }}
    name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }}
    runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
    env:
      PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
      S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
    outputs:
      # NOTE(review): presumably exported to the job env by
      # python_to_cp_version.py (via $GITHUB_ENV) — confirm.
      cp_version: ${{ env.cp_version }}
      jax_version: ${{ steps.write_jax_versions.outputs.jax_version }}
      jaxlib_version: ${{ steps.write_jax_versions.outputs.jaxlib_version }}
      jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }}
      jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }}
      jax_ref: ${{ matrix.jax_ref }}
    steps:
      - name: Checkout TheRock
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          # Same repository/ref selection as the other jobs, so build_tools/
          # comes from the requested TheRock checkout on cross-repo calls.
          repository: ${{ inputs.repository || github.repository }}
          ref: ${{ inputs.ref || '' }}

      - name: Checkout JAX
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          path: jax
          repository: rocm/rocm-jax
          ref: ${{ matrix.jax_ref }}

      # NOTE(review): "[email protected]" looks like web-scrape email
      # obfuscation, not the real bot address — restore the intended email.
      - name: Configure Git Identity
        run: |
          git config --global user.name "therockbot"
          git config --global user.email "[email protected]"

      - name: "Setting up Python"
        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
        with:
          python-version: ${{ inputs.python_version }}

      # Inputs reach the shell via env vars rather than inline `${{ }}`
      # interpolation, so their values cannot be parsed as shell syntax.
      - name: Select Python version
        env:
          BUILD_PYTHON_VERSION: ${{ inputs.python_version }}
        run: |
          python build_tools/github_actions/python_to_cp_version.py \
            --python-version "${BUILD_PYTHON_VERSION}"

      - name: Build JAX Wheels
        env:
          ROCM_VERSION: ${{ inputs.rocm_version }}
          BUILD_PYTHON_VERSION: ${{ inputs.python_version }}
          THEROCK_TAR_URL: ${{ inputs.tar_url }}
        run: |
          ls -lah
          pushd jax
          python3 build/ci_build \
            --compiler=clang \
            --python-versions="${BUILD_PYTHON_VERSION}" \
            --rocm-version="${ROCM_VERSION}" \
            --therock-path="${THEROCK_TAR_URL}" \
            dist_wheels

      - name: Extract JAX versions from built wheels
        id: write_jax_versions
        run: |
          python ./build_tools/github_actions/write_jax_versions.py \
            --dist-dir "${PACKAGE_DIST_DIR}"

      # All AWS/S3 steps are gated on the ROCm org, matching the upload steps.
      # Previously the credential step ran with `always()` everywhere; outside
      # ROCm the IAM role is presumably not assumable, which failed the job.
      - name: Install AWS CLI
        if: ${{ github.repository_owner == 'ROCm' }}
        run: bash ./dockerfiles/install_awscli.sh

      - name: Configure AWS Credentials
        if: ${{ github.repository_owner == 'ROCm' }}
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
        with:
          aws-region: us-east-2
          role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}

      - name: Upload wheels to S3
        if: ${{ github.repository_owner == 'ROCm' }}
        env:
          STAGING_PREFIX: "${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}"
        run: |
          aws s3 cp "${PACKAGE_DIST_DIR}/" "s3://${S3_BUCKET_PY}/${STAGING_PREFIX}/" \
            --recursive --exclude "*" --include "*.whl"

      - name: (Re-)Generate Python package release index
        if: ${{ github.repository_owner == 'ROCm' }}
        env:
          STAGING_PREFIX: "${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}"
        run: |
          python3 -m venv .venv
          source .venv/bin/activate
          pip3 install boto3 packaging
          python3 ./build_tools/third_party/s3_management/manage.py "${STAGING_PREFIX}"

  # Resolves which runner the tests should use for this GPU family.
  generate_target_to_run:
    name: Generate target_to_run
    runs-on: ubuntu-24.04
    outputs:
      test_runs_on: ${{ steps.configure.outputs.test-runs-on }}
      bypass_tests_for_releases: ${{ steps.configure.outputs.bypass_tests_for_releases }}
    steps:
      - name: Checking out repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          repository: ${{ inputs.repository || github.repository }}
          ref: ${{ inputs.ref || '' }}

      - name: Generating target to run
        id: configure
        env:
          TARGET: ${{ inputs.amdgpu_family }}
          PLATFORM: "linux"
          # Variable comes from ROCm organization variable 'ROCM_THEROCK_TEST_RUNNERS'
          ROCM_THEROCK_TEST_RUNNERS: ${{ vars.ROCM_THEROCK_TEST_RUNNERS }}
          LOAD_TEST_RUNNERS_FROM_VAR: false
        run: python ./build_tools/github_actions/configure_target_run.py

  # Installs the staged wheels from the staging index and runs the JAX tests.
  test_jax_wheels:
    name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }}
    needs:
      - build_jax_wheels
      - generate_target_to_run
    permissions:
      contents: read
      packages: read
    uses: ./.github/workflows/test_linux_jax_wheels.yml
    with:
      amdgpu_family: ${{ inputs.amdgpu_family }}
      release_type: ${{ inputs.release_type }}
      package_index_url: ${{ inputs.cloudfront_staging_url }}
      rocm_version: ${{ inputs.rocm_version }}
      tar_url: ${{ inputs.tar_url }}
      python_version: ${{ inputs.python_version }}
      repository: ${{ inputs.repository || github.repository }}
      ref: ${{ inputs.ref || '' }}
      jax_ref: ${{ needs.build_jax_wheels.outputs.jax_ref }}
      jax_version: ${{ needs.build_jax_wheels.outputs.jax_version }}
      jaxlib_version: ${{ needs.build_jax_wheels.outputs.jaxlib_version }}
      test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }}

  # Promotes the exact tested wheels from the staging prefix to the release
  # prefix when the promotion policy allows it.
  upload_jax_wheels:
    name: Release JAX Wheels to S3
    needs:
      - build_jax_wheels
      - generate_target_to_run
      - test_jax_wheels
    # Runs even when upstream jobs fail (their results are handed to the
    # promotion policy script), but only in the ROCm org, matching the gating
    # of the staging upload in build_jax_wheels.
    if: ${{ !cancelled() && github.repository_owner == 'ROCm' }}
    runs-on: ubuntu-24.04
    env:
      S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
      JAXLIB_VERSION: "${{ needs.build_jax_wheels.outputs.jaxlib_version }}"
      JAX_PLUGIN_VERSION: "${{ needs.build_jax_wheels.outputs.jax_plugin_version }}"
      JAX_PJRT_VERSION: "${{ needs.build_jax_wheels.outputs.jax_pjrt_version }}"
      CP_VERSION: "${{ needs.build_jax_wheels.outputs.cp_version }}"
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          repository: ${{ inputs.repository || github.repository }}
          ref: ${{ inputs.ref || '' }}

      - name: Configure AWS Credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
        with:
          aws-region: us-east-2
          role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}

      # NOTE(review): presumably writes `upload=true|false` to $GITHUB_ENV,
      # which the following steps test via `env.upload` — confirm.
      - name: Determine upload flag
        env:
          BUILD_RESULT: ${{ needs.build_jax_wheels.result }}
          TEST_RESULT: ${{ needs.test_jax_wheels.result }}
          TEST_RUNS_ON: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
          BYPASS_TESTS_FOR_RELEASES: ${{ needs.generate_target_to_run.outputs.bypass_tests_for_releases }}
        run: python ./build_tools/github_actions/promote_wheels_based_on_policy.py

      - name: Copy JAX wheels from staging to release S3
        if: ${{ env.upload == 'true' }}
        env:
          STAGING_PREFIX: "${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}"
          RELEASE_PREFIX: "${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}"
        run: |
          echo "Copying exact tested wheels to release S3 bucket..."
          echo " JAXLIB_VERSION=${JAXLIB_VERSION}"
          echo " JAX_PLUGIN_VERSION=${JAX_PLUGIN_VERSION}"
          echo " JAX_PJRT_VERSION=${JAX_PJRT_VERSION}"
          aws s3 cp \
            "s3://${S3_BUCKET_PY}/${STAGING_PREFIX}/" \
            "s3://${S3_BUCKET_PY}/${RELEASE_PREFIX}/" \
            --recursive \
            --exclude "*" \
            --include "jaxlib-${JAXLIB_VERSION}-${CP_VERSION}-manylinux_2_27_x86_64.whl" \
            --include "jax_rocm7_plugin-${JAX_PLUGIN_VERSION}-${CP_VERSION}-manylinux_2_28_x86_64.whl" \
            --include "jax_rocm7_pjrt-${JAX_PJRT_VERSION}-py3-none-manylinux_2_28_x86_64.whl"

      - name: (Re-)Generate Python package release index
        if: ${{ env.upload == 'true' }}
        env:
          # Environment variables to be set for `manage.py`
          CUSTOM_PREFIX: "${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}"
        run: |
          # Isolated venv (as in build_jax_wheels) so pip does not touch the
          # runner's system Python.
          python3 -m venv .venv
          source .venv/bin/activate
          pip3 install boto3 packaging
          python3 ./build_tools/third_party/s3_management/manage.py "${CUSTOM_PREFIX}"