Skip to content

Commit 3a2f1c6

Browse files
committed
Prepare for plugin 0.8.2 release
1 parent aef5f89 commit 3a2f1c6

23 files changed

+240
-99
lines changed

.github/workflows/rocm-perf.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,4 @@ jobs:
170170
--python-version "${{ needs.build-and-test-jax-perf.outputs.python_version }}" \
171171
--rocm-version "${{ needs.build-and-test-jax-perf.outputs.rocm_version }}" \
172172
--gfx-version gfx90a \
173-
--jax-version 0.8.0
173+
--jax-version 0.8.2

BUILDING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ python3 build/ci_build test $TEST_IMAGE --test-cmd "pytest jax_rocm_plugin/tests
136136
We keep unit tests in the `rocm/jax` repository, and you'll need to clone it
137137
to run the regular JAX unit tests with ROCm,
138138
```shell
139-
git clone --depth 1 --branch rocm-jaxlib-v0.8.0 [email protected]:ROCm/jax.git
139+
git clone --depth 1 --branch rocm-jaxlib-v0.8.2 [email protected]:ROCm/jax.git
140140
# Each release of the ROCm plugin has a corresponding branch. You can find
141141
# more at https://github.com/ROCm/rocm-jax/branches/all?query=rocm-jaxlib
142142

build/ci_build

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def dist_wheels(
114114
rocm_version,
115115
python_versions,
116116
xla_source_dir,
117+
jax_source_dir,
117118
rocm_build_job="",
118119
rocm_build_num="",
119120
therock_path=None,
@@ -137,37 +138,17 @@ def dist_wheels(
137138
xla_path = os.path.realpath(os.path.expanduser(xla_source_dir))
138139
cmd.append("--xla-source-dir=%s" % xla_path)
139140

141+
if jax_source_dir:
142+
jax_path = os.path.realpath(os.path.expanduser(jax_source_dir))
143+
cmd.append("--jax-source-dir=%s" % jax_path)
144+
140145
if rbe:
141146
cmd.append("--rbe")
142147

143148
cmd.append("dist_wheels")
144149
subprocess.check_call(cmd, cwd=jax_plugin_dir)
145150

146151

147-
def _fetch_jax_metadata(xla_path):
148-
cmd = ["git", "rev-parse", "HEAD"]
149-
jax_commit = subprocess.check_output(cmd)
150-
xla_commit = b""
151-
152-
if xla_path:
153-
try:
154-
xla_commit = subprocess.check_output(cmd, cwd=xla_path)
155-
except Exception as ex:
156-
LOG.warning("Exception while retrieving xla_commit: %s" % ex)
157-
158-
cmd = ["python3", "setup.py", "-V"]
159-
env = dict(os.environ)
160-
env["JAX_RELEASE"] = "1"
161-
162-
jax_version = subprocess.check_output(cmd, env=env)
163-
164-
return {
165-
"jax_version": jax_version.decode("utf8").strip(),
166-
"jax_commit": jax_commit.decode("utf8").strip(),
167-
"xla_commit": xla_commit.decode("utf8").strip(),
168-
}
169-
170-
171152
def _apply_filters(docker_filters, dockerfile_basename, docker_dir="docker"):
172153
"""
173154
Collect Dockerfile paths matching a basename prefix and optional substring filters.
@@ -540,7 +521,12 @@ def parse_args():
540521

541522
p.add_argument(
542523
"--xla-source-dir",
543-
help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
524+
help="Path to XLA source to use during plugin and jaxlib build, instead of builtin XLA",
525+
)
526+
527+
p.add_argument(
528+
"--jax-source-dir",
529+
help="Optional JAX source directory. When provided, builds jaxlib wheel and copies to wheelhouse.",
544530
)
545531

546532
p.add_argument(
@@ -623,6 +609,7 @@ def main():
623609
rocm_version=args.rocm_version,
624610
python_versions=args.python_versions,
625611
xla_source_dir=args.xla_source_dir,
612+
jax_source_dir=args.jax_source_dir,
626613
rocm_build_job=args.rocm_build_job,
627614
rocm_build_num=args.rocm_build_num,
628615
therock_path=args.therock_path,

build/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
jaxlib==0.8.0
2-
jax==0.8.0
1+
jaxlib==0.8.2
2+
jax==0.8.2

ci/Dockerfile.maxtext

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ WORKDIR /maxtext
1515
# Explicitly install jax,jaxlib to avoid pip pulling a newer version (e.g. 0.8.1)
1616
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
1717
pip install -r requirements.txt && \
18-
pip install jax==0.8.0 jaxlib==0.8.0 && pip freeze
18+
pip install jax==0.8.2 jaxlib==0.8.2 && pip freeze

ci/jax_rbe/pr_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ python3 build/build.py build --wheels=jax-rocm-plugin --configure_only --python_
4242
--config=rocm_rbe \
4343
--noremote_accept_cached \
4444
--//jax:build_jaxlib=false \
45-
--action_env=TF_ROCM_AMDGPU_TARGETS="gfx906,gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" \
45+
--action_env=TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" \
4646
--test_verbose_timeout_warnings \
4747
--test_output=errors \
4848
//tests:core_test_gpu \

docker/Dockerfile.base-ubu22

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ FROM ubuntu:22.04
44
# The Python version to use to build JAX ROCm plugin and pjrt.
55
ARG PY_VERSION=3.11.13
66
# The list of target devices to be supported by the JAX ROCm plugin and pjrt.
7-
ARG GPU_DEVICE_TARGETS="gfx906 gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
7+
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
88
# The ROCm version to be used inside the container.
99
ARG ROCM_VERSION
1010
# The installation path for ROCm.

docker/Dockerfile.base-ubu24

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ FROM ubuntu:24.04
22

33
### Container Build Arguments:
44
# The list of target devices to be supported by the JAX ROCm plugin and pjrt.
5-
ARG GPU_DEVICE_TARGETS="gfx906 gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
5+
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
66
# The ROCm version to be used inside the container.
77
ARG ROCM_VERSION
88
# The installation path for ROCm.

docker/Dockerfile.jax-ubu22

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,5 @@ LABEL com.amdgpu.jax_version="$JAX_VERSION" \
5656
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
5757
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
5858
ls -lah /wheelhouse && \
59-
/usr/local/bin/pip3 install -f /wheelhouse --no-deps --no-index "jax_rocm${PLUGIN_NAMESPACE}_plugin" "jax_rocm${PLUGIN_NAMESPACE}_pjrt"
59+
/usr/local/bin/pip3 install -f /wheelhouse --no-deps --no-index "jax_rocm${PLUGIN_NAMESPACE}_plugin" "jax_rocm${PLUGIN_NAMESPACE}_pjrt" && \
60+
/usr/local/bin/pip3 install -f /wheelhouse --no-deps --no-index --force-reinstall "jaxlib"

docker/Dockerfile.jax-ubu24

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,5 @@ LABEL com.amdgpu.jax_version="$JAX_VERSION" \
5555
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
5656
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
5757
ls -lah /wheelhouse && \
58-
pip3 install -f /wheelhouse --no-deps --no-index "jax_rocm${PLUGIN_NAMESPACE}_plugin" "jax_rocm${PLUGIN_NAMESPACE}_pjrt"
58+
pip3 install -f /wheelhouse --no-deps --no-index "jax_rocm${PLUGIN_NAMESPACE}_plugin" "jax_rocm${PLUGIN_NAMESPACE}_pjrt" && \
59+
pip3 install -f /wheelhouse --no-deps --no-index --force-reinstall "jaxlib"

0 commit comments

Comments
 (0)