Skip to content

Commit 4c6190f

Browse files
Adding rpath for /opt/rocm/lib since librccl.so etc. is not found (#135)
* Adding rpath for /opt/rocm/lib since librccl.so etc. is not found
* fix linting
* remove --force-rpath to enable user to set LD_LIBRARY_PATH
* black linting fixes
* edit pylint config to disable duplicate-code warning; will refactor the whole repo in the future.
1 parent 7abb1a8 commit 4c6190f

File tree

3 files changed

+219
-199
lines changed

3 files changed

+219
-199
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
[MESSAGES CONTROL]
22

3-
disable = consider-using-f-string
3+
disable = consider-using-f-string,duplicate-code

jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py

Lines changed: 116 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,11 @@
1616
# run via bazel run as part of the jax cuda plugin build process.
1717

1818
# Most users should not run this script directly; use build.py instead.
19+
# pylint: disable=duplicate-code
20+
21+
"""
22+
Script to build a JAX ROCm kernel plugin wheel. Intended for use via Bazel.
23+
"""
1924

2025
import argparse
2126
import functools
@@ -25,6 +30,7 @@
2530
import subprocess
2631
import tempfile
2732

33+
# pylint: disable=import-error
2834
from bazel_tools.tools.python.runfiles import runfiles
2935
from jaxlib_ext.tools import build_utils
3036

@@ -58,120 +64,125 @@
5864
parser.add_argument(
5965
"--enable-cuda",
6066
default=False,
61-
help="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
67+
help="Should we build with CUDA enabled? Requires CUDA and CuDNN.",
68+
)
6269
parser.add_argument(
63-
"--enable-rocm",
64-
default=False,
65-
help="Should we build with ROCM enabled?")
70+
"--enable-rocm", default=False, help="Should we build with ROCM enabled?"
71+
)
6672
args = parser.parse_args()
6773

6874
r = runfiles.Create()
69-
pyext = "pyd" if build_utils.is_windows() else "so"
75+
PYEXT = "pyd" if build_utils.is_windows() else "so"
7076

7177

72-
def write_setup_cfg(sources_path, cpu):
73-
tag = build_utils.platform_tag(cpu)
74-
with open(sources_path / "setup.cfg", "w") as f:
75-
f.write(f"""[metadata]
78+
def write_setup_cfg(setup_cfg_path, cpu):
79+
"""Write setup.cfg with platform tag."""
80+
tag = build_utils.platform_tag(cpu)
81+
with open(setup_cfg_path / "setup.cfg", "w", encoding="utf-8") as cfg_file:
82+
cfg_file.write(
83+
f"""[metadata]
7684
license_files = LICENSE.txt
7785
7886
[bdist_wheel]
7987
plat_name={tag}
80-
""")
81-
82-
83-
def prepare_wheel_rocm(
84-
sources_path: pathlib.Path, *, cpu, rocm_version
85-
):
86-
"""Assembles a source tree for the rocm kernel wheel in `sources_path`."""
87-
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
88-
89-
copy_runfiles(
90-
"__main__/jax_plugins/rocm/plugin_pyproject.toml",
91-
dst_dir=sources_path,
92-
dst_filename="pyproject.toml",
93-
)
94-
copy_runfiles(
95-
"__main__/jax_plugins/rocm/plugin_setup.py",
96-
dst_dir=sources_path,
97-
dst_filename="setup.py",
98-
)
99-
build_utils.update_setup_with_rocm_version(sources_path, rocm_version)
100-
write_setup_cfg(sources_path, cpu)
101-
102-
plugin_dir = sources_path / f"jax_rocm{rocm_version}_plugin"
103-
copy_runfiles(
104-
dst_dir=plugin_dir,
105-
src_files=[
106-
f"jax/jaxlib/rocm/_linalg.{pyext}",
107-
f"jax/jaxlib/rocm/_prng.{pyext}",
108-
f"jax/jaxlib/rocm/_solver.{pyext}",
109-
f"jax/jaxlib/rocm/_sparse.{pyext}",
110-
f"jax/jaxlib/rocm/_hybrid.{pyext}",
111-
f"jax/jaxlib/rocm/_rnn.{pyext}",
112-
f"jax/jaxlib/rocm/_triton.{pyext}",
113-
f"jax/jaxlib/rocm/rocm_plugin_extension.{pyext}",
114-
"jax/jaxlib/version.py",
115-
],
116-
)
117-
118-
# NOTE(mrodden): this is a hack to change/set rpath values
119-
# in the shared objects that are produced by the bazel build
120-
# before they get pulled into the wheel build process.
121-
# we have to do this change here because setting rpath
122-
# using bazel requires the rpath to be valid during the build
123-
# which won't be correct until we make changes to
124-
# the xla/tsl/jax plugin build
125-
126-
try:
127-
subprocess.check_output(["which", "patchelf"])
128-
except subprocess.CalledProcessError as ex:
129-
mesg = (
130-
"rocm plugin and kernel wheel builds require patchelf. "
131-
"please install 'patchelf' and run again"
88+
"""
89+
)
90+
91+
92+
def prepare_wheel_rocm(rocm_sources_path: pathlib.Path, *, cpu, rocm_version):
93+
"""Assembles a source tree for the rocm kernel wheel in `rocm_sources_path`."""
94+
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
95+
96+
copy_runfiles(
97+
"__main__/jax_plugins/rocm/plugin_pyproject.toml",
98+
dst_dir=rocm_sources_path,
99+
dst_filename="pyproject.toml",
132100
)
133-
raise Exception(mesg) from ex
134-
135-
files = [
136-
f"_linalg.{pyext}",
137-
f"_prng.{pyext}",
138-
f"_solver.{pyext}",
139-
f"_sparse.{pyext}",
140-
f"_hybrid.{pyext}",
141-
f"_rnn.{pyext}",
142-
f"_triton.{pyext}",
143-
f"rocm_plugin_extension.{pyext}",
144-
]
145-
runpath = '$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib'
146-
# patchelf --force-rpath --set-rpath $RUNPATH $so
147-
for f in files:
148-
so_path = os.path.join(plugin_dir, f)
149-
fix_perms = False
150-
perms = os.stat(so_path).st_mode
151-
if not perms & stat.S_IWUSR:
152-
fix_perms = True
153-
os.chmod(so_path, perms | stat.S_IWUSR)
154-
subprocess.check_call(["patchelf", "--force-rpath", "--set-rpath", runpath, so_path])
155-
if fix_perms:
156-
os.chmod(so_path, perms)
157-
158-
tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin")
159-
sources_path = tmpdir.name
160-
try:
161-
os.makedirs(args.output_path, exist_ok=True)
162-
prepare_wheel_rocm(
163-
pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version
164-
)
165-
package_name = f"jax rocm{args.platform_version} plugin"
166-
if args.editable:
167-
build_utils.build_editable(sources_path, args.output_path, package_name)
168-
else:
169-
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
170-
build_utils.build_wheel(
171-
sources_path,
172-
args.output_path,
173-
package_name,
174-
git_hash=git_hash,
101+
copy_runfiles(
102+
"__main__/jax_plugins/rocm/plugin_setup.py",
103+
dst_dir=rocm_sources_path,
104+
dst_filename="setup.py",
175105
)
176-
finally:
177-
tmpdir.cleanup()
106+
build_utils.update_setup_with_rocm_version(rocm_sources_path, rocm_version)
107+
write_setup_cfg(rocm_sources_path, cpu)
108+
109+
plugin_dir = rocm_sources_path / f"jax_rocm{rocm_version}_plugin"
110+
copy_runfiles(
111+
dst_dir=plugin_dir,
112+
src_files=[
113+
f"jax/jaxlib/rocm/_linalg.{PYEXT}",
114+
f"jax/jaxlib/rocm/_prng.{PYEXT}",
115+
f"jax/jaxlib/rocm/_solver.{PYEXT}",
116+
f"jax/jaxlib/rocm/_sparse.{PYEXT}",
117+
f"jax/jaxlib/rocm/_hybrid.{PYEXT}",
118+
f"jax/jaxlib/rocm/_rnn.{PYEXT}",
119+
f"jax/jaxlib/rocm/_triton.{PYEXT}",
120+
f"jax/jaxlib/rocm/rocm_plugin_extension.{PYEXT}",
121+
"jax/jaxlib/version.py",
122+
],
123+
)
124+
125+
# NOTE(mrodden): this is a hack to change/set rpath values
126+
# in the shared objects that are produced by the bazel build
127+
# before they get pulled into the wheel build process.
128+
# we have to do this change here because setting rpath
129+
# using bazel requires the rpath to be valid during the build
130+
# which won't be correct until we make changes to
131+
# the xla/tsl/jax plugin build
132+
133+
try:
134+
subprocess.check_output(["which", "patchelf"])
135+
except subprocess.CalledProcessError as ex:
136+
mesg = (
137+
"rocm plugin and kernel wheel builds require patchelf. "
138+
"please install 'patchelf' and run again"
139+
)
140+
raise RuntimeError(mesg) from ex
141+
142+
files = [
143+
f"_linalg.{PYEXT}",
144+
f"_prng.{PYEXT}",
145+
f"_solver.{PYEXT}",
146+
f"_sparse.{PYEXT}",
147+
f"_hybrid.{PYEXT}",
148+
f"_rnn.{PYEXT}",
149+
f"_triton.{PYEXT}",
150+
f"rocm_plugin_extension.{PYEXT}",
151+
]
152+
runpath = "$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib:/opt/rocm/lib"
153+
# patchelf --set-rpath $RUNPATH $so
154+
for fname in files:
155+
so_path = os.path.join(plugin_dir, fname)
156+
fix_perms = False
157+
perms = os.stat(so_path).st_mode
158+
if not perms & stat.S_IWUSR:
159+
fix_perms = True
160+
os.chmod(so_path, perms | stat.S_IWUSR)
161+
subprocess.check_call(["patchelf", "--set-rpath", runpath, so_path])
162+
if fix_perms:
163+
os.chmod(so_path, perms)
164+
165+
166+
def main():
167+
"""Main entry point for building the ROCm kernel plugin wheel."""
168+
with tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") as tmpdir:
169+
sources_path = tmpdir
170+
os.makedirs(args.output_path, exist_ok=True)
171+
prepare_wheel_rocm(
172+
pathlib.Path(sources_path), cpu=args.cpu, rocm_version=args.platform_version
173+
)
174+
package_name = f"jax rocm{args.platform_version} plugin"
175+
if args.editable:
176+
build_utils.build_editable(sources_path, args.output_path, package_name)
177+
else:
178+
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
179+
build_utils.build_wheel(
180+
sources_path,
181+
args.output_path,
182+
package_name,
183+
git_hash=git_hash,
184+
)
185+
186+
187+
if __name__ == "__main__":
188+
main()

0 commit comments

Comments
 (0)