Skip to content

Commit dd796d2

Browse files
authored
Merge branch 'main' into jechrist/bumpJaxTo082
2 parents 63f45f3 + 7d121cf commit dd796d2

File tree

139 files changed

+5120
-1026
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+5120
-1026
lines changed

.github/actions/setup_test_environment/action.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ runs:
5858
shell: bash
5959
run: ./dockerfiles/install_awscli.sh
6060

61-
- name: Install the AWS tool
61+
- name: Install Choco tools
6262
if: ${{ runner.os == 'Windows' }}
6363
shell: bash
6464
run: |
65-
choco install --no-progress -y awscli
65+
choco install --no-progress -y ninja awscli
6666
echo "$PATH;C:\Program Files\Amazon\AWSCLIV2" >> $GITHUB_PATH
6767
6868
- name: Download and Unpack Artifacts

.github/workflows/build_linux_jax_wheels.yml

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ on:
4242
ref:
4343
description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
4444
type: string
45+
jax_ref:
46+
description: "rocm-jax repository ref/branch to check out"
47+
type: string
48+
default: "rocm-jaxlib-v0.8.0-fixdevtar"
4549
workflow_dispatch:
4650
inputs:
4751
amdgpu_family:
@@ -93,6 +97,10 @@ on:
9397
description: CloudFront base URL pointing to staging Python index
9498
type: string
9599
default: "https://rocm.devreleases.amd.com/v2-staging"
100+
ref:
101+
description: TheRock repository ref/branch to check out
102+
type: string
103+
default: ""
96104
jax_ref:
97105
description: rocm-jax repository ref/branch to check out
98106
type: string
@@ -116,7 +124,11 @@ jobs:
116124
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
117125
outputs:
118126
cp_version: ${{ env.cp_version }}
119-
jax_version: ${{ steps.extract_jax_version.outputs.jax_version }}
127+
jax_version: ${{ steps.write_jax_versions.outputs.jax_version }}
128+
jaxlib_version: ${{ steps.write_jax_versions.outputs.jaxlib_version }}
129+
jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }}
130+
jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }}
131+
jax_ref: ${{ matrix.jax_ref }}
120132
steps:
121133
- name: Checkout TheRock
122134
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -156,25 +168,19 @@ jobs:
156168
--therock-path="${{ inputs.tar_url }}" \
157169
dist_wheels
158170
159-
- name: Extract JAX version
160-
id: extract_jax_version
171+
- name: Extract JAX versions from built wheels
172+
id: write_jax_versions
161173
run: |
162-
# Extract JAX version from requirements.txt (e.g., "jax==0.8.2")
163-
# Remove all whitespace from requirements.txt to simplify parsing
164-
# Search for lines starting with "jax==" or "jaxlib==" followed by version (excluding comments)
165-
# Extract the version number by splitting on '=' and taking the 3rd field
166-
# [^#]+ matches one or more characters that are NOT '#', ensuring we stop before any inline comments
167-
JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \
168-
| grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3)
169-
echo "jax_version=$JAX_VERSION" >> "$GITHUB_OUTPUT"
174+
python ./build_tools/github_actions/write_jax_versions.py \
175+
--dist-dir ${{ env.PACKAGE_DIST_DIR }}
170176
171177
- name: Install AWS CLI
172178
if: always()
173179
run: bash ./dockerfiles/install_awscli.sh
174180

175181
- name: Configure AWS Credentials
176182
if: always()
177-
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
183+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
178184
with:
179185
aws-region: us-east-2
180186
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}
@@ -226,14 +232,15 @@ jobs:
226232
with:
227233
amdgpu_family: ${{ inputs.amdgpu_family }}
228234
release_type: ${{ inputs.release_type }}
229-
s3_subdir: ${{ inputs.s3_subdir }}
230235
package_index_url: ${{ inputs.cloudfront_staging_url }}
231236
rocm_version: ${{ inputs.rocm_version }}
232237
tar_url: ${{ inputs.tar_url }}
233238
python_version: ${{ inputs.python_version }}
234239
repository: ${{ inputs.repository || github.repository }}
235240
ref: ${{ inputs.ref || '' }}
236-
jax_ref: ${{ inputs.jax_ref }}
241+
jax_ref: ${{ needs.build_jax_wheels.outputs.jax_ref }}
242+
jax_version: ${{ needs.build_jax_wheels.outputs.jax_version }}
243+
jaxlib_version: ${{ needs.build_jax_wheels.outputs.jaxlib_version }}
237244
test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }}
238245

239246
upload_jax_wheels:
@@ -243,8 +250,9 @@ jobs:
243250
runs-on: ubuntu-24.04
244251
env:
245252
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
246-
JAX_VERSION: "${{ needs.build_jax_wheels.outputs.jax_version }}"
247-
ROCM_VERSION: "${{ inputs.rocm_version }}"
253+
JAXLIB_VERSION: "${{ needs.build_jax_wheels.outputs.jaxlib_version }}"
254+
JAX_PLUGIN_VERSION: "${{ needs.build_jax_wheels.outputs.jax_plugin_version }}"
255+
JAX_PJRT_VERSION: "${{ needs.build_jax_wheels.outputs.jax_pjrt_version }}"
248256
CP_VERSION: "${{ needs.build_jax_wheels.outputs.cp_version }}"
249257

250258
steps:
@@ -256,7 +264,7 @@ jobs:
256264

257265
- name: Configure AWS Credentials
258266
if: always()
259-
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
267+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
260268
with:
261269
aws-region: us-east-2
262270
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}
@@ -273,14 +281,17 @@ jobs:
273281
if: ${{ env.upload == 'true' }}
274282
run: |
275283
echo "Copying exact tested wheels to release S3 bucket..."
284+
echo " JAXLIB_VERSION=${JAXLIB_VERSION}"
285+
echo " JAX_PLUGIN_VERSION=${JAX_PLUGIN_VERSION}"
286+
echo " JAX_PJRT_VERSION=${JAX_PJRT_VERSION}"
276287
aws s3 cp \
277288
s3://${S3_BUCKET_PY}/${{ inputs.s3_staging_subdir }}/${{ inputs.amdgpu_family }}/ \
278289
s3://${S3_BUCKET_PY}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}/ \
279290
--recursive \
280291
--exclude "*" \
281-
--include "jaxlib-${JAX_VERSION}+rocm${ROCM_VERSION}-${CP_VERSION}-manylinux_2_27_x86_64.whl" \
282-
--include "jax_rocm7_plugin-${JAX_VERSION}+rocm${ROCM_VERSION}-${CP_VERSION}-manylinux_2_28_x86_64.whl" \
283-
--include "jax_rocm7_pjrt-${JAX_VERSION}+rocm${ROCM_VERSION}-py3-none-manylinux_2_28_x86_64.whl"
292+
--include "jaxlib-${JAXLIB_VERSION}-${CP_VERSION}-manylinux_2_27_x86_64.whl" \
293+
--include "jax_rocm7_plugin-${JAX_PLUGIN_VERSION}-${CP_VERSION}-manylinux_2_28_x86_64.whl" \
294+
--include "jax_rocm7_pjrt-${JAX_PJRT_VERSION}-py3-none-manylinux_2_28_x86_64.whl"
284295
285296
- name: (Re-)Generate Python package release index
286297
if: ${{ env.upload == 'true' }}

.github/workflows/build_native_linux_packages.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ on:
2020
required: true
2121
type: string
2222
package_suffix:
23-
description: The suffix to be added to package name (asan, static or rpath).
23+
description: The suffix to be added to package name (asan, tsan, static or rpath).
2424
required: false
2525
type: string
2626
release_type:
@@ -48,7 +48,7 @@ on:
4848
- deb
4949
default: "rpm"
5050
package_suffix:
51-
description: The suffix to be added to package name (asan, static or rpath).
51+
description: The suffix to be added to package name (asan, tsan, static or rpath).
5252
type: string
5353
required: false
5454
release_type:
@@ -120,7 +120,7 @@ jobs:
120120
run: bash ./dockerfiles/install_awscli.sh
121121

122122
- name: Configure AWS Credentials for non-forked repos
123-
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
123+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
124124
with:
125125
aws-region: us-east-2
126126
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}

.github/workflows/build_portable_linux_artifacts.yml

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ on:
1111
default: gfx94X-dcgpu
1212
build_variant_label:
1313
type: string
14-
description: "A label for the build variant (ex: 'release', 'asan')"
14+
description: "A label for the build variant (ex: 'release', 'asan', 'tsan')"
1515
default: "release"
1616
build_variant_suffix:
1717
type: string
@@ -59,15 +59,17 @@ jobs:
5959
build_portable_linux_artifacts:
6060
name: Build (xfail ${{ inputs.expect_failure }})
6161
# azure-linux-scale-rocm are used for regular CI builds
62-
# azure-linux-scale-rocm-heavy are used for CI builds that require more resources (ex: ASAN builds)
63-
runs-on: ${{ inputs.build_variant_label == 'asan' && 'azure-linux-u2404-hx176-cpu-rocm' || 'azure-linux-scale-rocm' }}
62+
# azure-linux-scale-rocm-heavy are used for CI builds that require more resources (ex: ASAN and TSAN builds)
63+
runs-on: ${{ (contains(inputs.build_variant_label, 'asan') || contains(inputs.build_variant_label, 'tsan')) && 'azure-linux-scale-rocm-heavy' || 'azure-linux-scale-rocm' }}
6464
continue-on-error: ${{ inputs.expect_failure }}
6565
timeout-minutes: 720 # 12 hour timeout
6666
permissions:
6767
id-token: write
6868
container:
6969
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:db2b63f938941dde2abc80b734e64b45b9995a282896d513a0f3525d4591d6cb
70-
options: -v /runner/config:/home/awsconfig/
70+
# --cap-add=SYS_PTRACE : to enable ptrace insided the build container for tsan builds
71+
# --security-opt seccomp=unconfined : to disable the system call filtering for tsan builds
72+
options: -v /runner/config:/home/awsconfig/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined
7173
env:
7274
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
7375
CACHE_DIR: ${{ github.workspace }}/.container-cache
@@ -105,10 +107,6 @@ jobs:
105107
run: |
106108
./build_tools/health_status.py
107109
108-
- name: Test build_tools
109-
run: |
110-
python -m pytest build_tools/tests build_tools/github_actions/tests
111-
112110
- name: Fetch sources
113111
timeout-minutes: 30
114112
run: |
@@ -166,7 +164,7 @@ jobs:
166164
167165
- name: Configure AWS Credentials for non-forked repos
168166
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
169-
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
167+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
170168
with:
171169
aws-region: us-east-2
172170
role-to-assume: arn:aws:iam::692859939525:role/therock-ci

.github/workflows/build_portable_linux_python_packages.yml

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,66 +27,71 @@ on:
2727
type: string
2828
package_version:
2929
type: string
30+
outputs:
31+
package_find_links_url:
32+
description: URL for pip --find-links to install built packages
33+
value: ${{ jobs.build_rocm_wheels.outputs.package_find_links_url }}
3034

3135
permissions:
3236
contents: read
3337

3438
run-name: Build portable Linux Python Packages (${{ inputs.artifact_group }}, ${{ inputs.package_version }})
3539

3640
jobs:
37-
build:
41+
build_rocm_wheels:
3842
name: Build Python | ${{ inputs.artifact_group }}
3943
# Note: GitHub-hosted runners run out of disk space for some gpu families
4044
runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }}
45+
permissions:
46+
id-token: write
47+
container:
48+
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:db2b63f938941dde2abc80b734e64b45b9995a282896d513a0f3525d4591d6cb
49+
options: -v /runner/config:/home/awsconfig/
50+
outputs:
51+
package_find_links_url: ${{ steps.upload.outputs.package_find_links_url }}
4152
env:
42-
BUILD_IMAGE: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:db2b63f938941dde2abc80b734e64b45b9995a282896d513a0f3525d4591d6cb
53+
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
4354
ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}"
4455
ARTIFACTS_DIR: "${{ github.workspace }}/artifacts"
4556
PACKAGES_DIR: "${{ github.workspace }}/packages"
46-
MANYLINUX: 1
57+
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
4758

4859
steps:
49-
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
50-
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
51-
with:
52-
python-version: '3.12'
60+
- name: Checkout
61+
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
5362

5463
- name: Install Python requirements
55-
run: pip install boto3 packaging piprepo setuptools
64+
run: pip install -r requirements.txt
5665

57-
# Note: we could fetch "all" artifacts if we wanted to include more files
5866
- name: Fetch artifacts
59-
env:
60-
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
6167
run: |
6268
python ./build_tools/fetch_artifacts.py \
63-
--run-github-repo=${{ inputs.artifact_github_repo }} \
64-
--run-id=${{ env.ARTIFACT_RUN_ID }} \
65-
--artifact-group=${{ inputs.artifact_group }} \
66-
--output-dir=${{ env.ARTIFACTS_DIR }} \
67-
_dev_ _lib_ _run_
69+
--run-github-repo="${{ inputs.artifact_github_repo }}" \
70+
--run-id="${{ env.ARTIFACT_RUN_ID }}" \
71+
--artifact-group="${{ inputs.artifact_group }}" \
72+
--output-dir="${{ env.ARTIFACTS_DIR }}"
6873
6974
- name: Build Python packages
7075
run: |
71-
./build_tools/linux_portable_build.py \
72-
--image=${{ env.BUILD_IMAGE }} \
73-
--output-dir=${{ env.PACKAGES_DIR }} \
74-
--artifact-dir=${{ env.ARTIFACTS_DIR }} \
75-
--build-python-only \
76-
-- \
77-
"--version=${{ inputs.package_version }}"
78-
79-
- name: Inspect Python packages
80-
run: |
81-
ls -la "${{ env.PACKAGES_DIR }}"
76+
python ./build_tools/build_python_packages.py \
77+
--artifact-dir="${{ env.ARTIFACTS_DIR }}" \
78+
--dest-dir="${{ env.PACKAGES_DIR }}" \
79+
--version="${{ inputs.package_version }}"
8280
83-
# TODO(#1559): Sanity check (Linux can't find the directories, maybe Docker issues?)
84-
85-
# - name: Sanity check Python packages
86-
# run: |
87-
# piprepo build "${{ env.PACKAGES_DIR }}/dist"
88-
# pip install rocm[devel]==${{ inputs.package_version }} \
89-
# --extra-index-url "${{ env.PACKAGES_DIR }}/dist/simple/"
90-
# rocm-sdk test
81+
- name: Configure AWS Credentials for non-forked repos
82+
if: ${{ !github.event.pull_request.head.repo.fork }}
83+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
84+
with:
85+
aws-region: us-east-2
86+
role-to-assume: arn:aws:iam::692859939525:role/therock-ci
9187

92-
# TODO(#1559): upload packages to artifacts S3 bucket and/or a dedicated Python packages bucket
88+
# NOTE: we use `github.run_id` and NOT `env.ARTIFACT_RUN_ID` here!
89+
# This ensures that if they are different we _download_ artifacts from the
90+
# input run's subdirectory and _upload_ to our current run's subdirectory.
91+
- name: Upload Python packages
92+
id: upload
93+
run: |
94+
python build_tools/github_actions/upload_python_packages.py \
95+
--input-packages-dir="${{ env.PACKAGES_DIR }}" \
96+
--artifact-group="${{ inputs.artifact_group }}" \
97+
--run-id="${{ github.run_id }}"

.github/workflows/build_portable_linux_pytorch_wheels.yml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ jobs:
126126
# packages, for example via `pip install torch==${TORCH_VERSION}`.
127127
torch_version: ${{ steps.build-pytorch-wheels.outputs.torch_version }}
128128
torchaudio_version: ${{ steps.build-pytorch-wheels.outputs.torchaudio_version }}
129+
apex_version: ${{ steps.build-pytorch-wheels.outputs.apex_version }}
129130
torchvision_version: ${{ steps.build-pytorch-wheels.outputs.torchvision_version }}
130131
triton_version: ${{ steps.build-pytorch-wheels.outputs.triton_version }}
131132
steps:
@@ -160,6 +161,7 @@ jobs:
160161
run: |
161162
./external-builds/pytorch/pytorch_torch_repo.py checkout --repo-hashtag nightly
162163
./external-builds/pytorch/pytorch_audio_repo.py checkout --repo-hashtag nightly
164+
./external-builds/pytorch/pytorch_apex_repo.py checkout --repo-hashtag master
163165
./external-builds/pytorch/pytorch_vision_repo.py checkout --repo-hashtag nightly
164166
./external-builds/pytorch/pytorch_triton_repo.py checkout
165167
@@ -169,6 +171,7 @@ jobs:
169171
run: |
170172
./external-builds/pytorch/pytorch_torch_repo.py checkout --gitrepo-origin https://github.com/ROCm/pytorch.git --repo-hashtag ${{ inputs.pytorch_git_ref }}
171173
./external-builds/pytorch/pytorch_audio_repo.py checkout --require-related-commit
174+
./external-builds/pytorch/pytorch_apex_repo.py checkout --require-related-commit
172175
./external-builds/pytorch/pytorch_vision_repo.py checkout --require-related-commit
173176
./external-builds/pytorch/pytorch_triton_repo.py checkout
174177
@@ -201,7 +204,7 @@ jobs:
201204
202205
- name: Configure AWS Credentials
203206
if: always()
204-
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
207+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
205208
with:
206209
aws-region: us-east-2
207210
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}
@@ -267,6 +270,7 @@ jobs:
267270
env:
268271
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
269272
CP_VERSION: "${{ needs.build_pytorch_wheels.outputs.cp_version }}"
273+
APEX_VERSION: "${{ needs.build_pytorch_wheels.outputs.apex_version }}"
270274
TORCH_VERSION: "${{ needs.build_pytorch_wheels.outputs.torch_version }}"
271275
TORCHAUDIO_VERSION: "${{ needs.build_pytorch_wheels.outputs.torchaudio_version }}"
272276
TORCHVISION_VERSION: "${{ needs.build_pytorch_wheels.outputs.torchvision_version }}"
@@ -281,7 +285,7 @@ jobs:
281285

282286
- name: Configure AWS Credentials
283287
if: always()
284-
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708 # v5.1.1
288+
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
285289
with:
286290
aws-region: us-east-2
287291
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}
@@ -306,7 +310,8 @@ jobs:
306310
--include "torch-${TORCH_VERSION}-${CP_VERSION}-linux_x86_64.whl" \
307311
--include "torchaudio-${TORCHAUDIO_VERSION}-${CP_VERSION}-linux_x86_64.whl" \
308312
--include "torchvision-${TORCHVISION_VERSION}-${CP_VERSION}-linux_x86_64.whl" \
309-
--include "triton-${TRITON_VERSION}-${CP_VERSION}-linux_x86_64.whl"
313+
--include "triton-${TRITON_VERSION}-${CP_VERSION}-linux_x86_64.whl" \
314+
--include "apex-${APEX_VERSION}-${CP_VERSION}-linux_x86_64.whl"
310315
311316
- name: (Re-)Generate Python package release index
312317
if: ${{ env.upload == 'true' }}

0 commit comments

Comments
 (0)