Skip to content

Commit 53703c0

Browse files
adding last polishes to wheels (#136)
* adding last polishes to wheels * fix linting
1 parent 4c6190f commit 53703c0

File tree

2 files changed

+65
-44
lines changed

2 files changed

+65
-44
lines changed

jax_rocm_plugin/jax_plugins/rocm/plugin_setup.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,56 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
# pylint: disable=protected-access
15+
"""Setup script for the ROCm JAX runtime Python convenience wheel.
16+
17+
Minimal file whose only responsibility is to expose the correct version and
18+
package metadata for the ROCm plugin. The real build logic lives in the
19+
version module that is generated alongside the package.
20+
"""
1421

1522
import importlib
1623
import os
1724
from setuptools import setup
1825
from setuptools.dist import Distribution
1926

2027
__version__ = None
21-
rocm_version = 0 # placeholder
22-
project_name = f"jax-rocm{rocm_version}-plugin"
23-
package_name = f"jax_rocm{rocm_version}_plugin"
28+
rocm_version = 0 # placeholder # pylint: disable=invalid-name
29+
project_name = f"jax-rocm{rocm_version}-plugin" # pylint: disable=invalid-name
30+
package_name = f"jax_rocm{rocm_version}_plugin" # pylint: disable=invalid-name
2431

2532
# Extract ROCm version from the `ROCM_PATH` environment variable.
26-
default_rocm_path = "/opt/rocm"
27-
rocm_path = os.getenv("ROCM_PATH", default_rocm_path)
28-
rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown"
33+
DEFAULT_ROCM_PATH = "/opt/rocm"
34+
rocm_path = os.getenv("ROCM_PATH", DEFAULT_ROCM_PATH)
35+
rocm_detected_version = rocm_path.split("-")[-1] if "-" in rocm_path else "7.0"
36+
2937

3038
def load_version_module(pkg_path):
31-
spec = importlib.util.spec_from_file_location(
32-
'version', os.path.join(pkg_path, 'version.py'))
33-
module = importlib.util.module_from_spec(spec)
34-
spec.loader.exec_module(module)
35-
return module
39+
"""Dynamically import and return the version helper module for the package."""
40+
spec = importlib.util.spec_from_file_location(
41+
"version", os.path.join(pkg_path, "version.py")
42+
)
43+
module = importlib.util.module_from_spec(spec)
44+
spec.loader.exec_module(module) # type: ignore[attr-defined]
45+
return module
46+
3647

3748
_version_module = load_version_module(package_name)
38-
__version__ = _version_module._get_version_for_build()
39-
_cmdclass = _version_module._get_cmdclass(package_name)
49+
__version__ = (
50+
_version_module._get_version_for_build()
51+
) # protected helper from generated module
52+
_cmdclass = _version_module._get_cmdclass(
53+
package_name
54+
) # protected helper from generated module
55+
4056

4157
class BinaryDistribution(Distribution):
42-
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
58+
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
59+
60+
def has_ext_modules(self): # type: ignore[override]
61+
"""Return True to force wheel build to include an ABI tag."""
62+
return True
4363

44-
def has_ext_modules(self):
45-
return True
4664

4765
setup(
4866
name=project_name,
@@ -51,18 +69,16 @@ def has_ext_modules(self):
5169
description=f"JAX Plugin for AMD GPUs (ROCm:{rocm_detected_version})",
5270
long_description="",
5371
long_description_content_type="text/markdown",
54-
author="Ruturaj4",
55-
author_email="Ruturaj.Vaidya@amd.com",
72+
author="GulsumGA-AMD",
73+
author_email="Gulsum.GudukbayAkbulut@amd.com",
5674
packages=[package_name],
57-
python_requires=">=3.9",
75+
python_requires=">=3.10",
5876
install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"],
59-
url="https://github.com/jax-ml/jax",
77+
url="https://github.com/ROCm/rocm-jax",
6078
license="Apache-2.0",
6179
classifiers=[
62-
"Development Status :: 3 - Alpha",
63-
"Programming Language :: Python :: 3.9",
80+
"Development Status :: 5 - Production/Stable",
6481
"Programming Language :: Python :: 3.10",
65-
"Programming Language :: Python :: 3.11",
6682
"Programming Language :: Python :: 3.12",
6783
],
6884
package_data={

jax_rocm_plugin/pjrt/python/setup.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,52 +11,57 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
# pylint: disable=protected-access
15+
"""Setup script for the ROCm JAX PJRT plugin package.
16+
17+
Holds only lightweight metadata wiring; build/version logic resides in the
18+
generated version helper inside the package hierarchy.
19+
"""
1420

1521
import importlib
1622
import os
1723
from setuptools import setup, find_namespace_packages
1824

1925
__version__ = None
20-
rocm_version = 0 # placeholder
21-
project_name = f"jax-rocm{rocm_version}-pjrt"
22-
package_name = f"jax_plugins.xla_rocm{rocm_version}"
26+
rocm_version = 0 # placeholder (runtime substituted) # pylint: disable=invalid-name
27+
project_name = f"jax-rocm{rocm_version}-pjrt" # pylint: disable=invalid-name
28+
package_name = f"jax_plugins.xla_rocm{rocm_version}" # pylint: disable=invalid-name
2329

2430
# Extract ROCm version from the `ROCM_PATH` environment variable.
25-
default_rocm_path = "/opt/rocm"
26-
rocm_path = os.getenv("ROCM_PATH", default_rocm_path)
27-
rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown"
31+
DEFAULT_ROCM_PATH = "/opt/rocm"
32+
rocm_path = os.getenv("ROCM_PATH", DEFAULT_ROCM_PATH)
33+
rocm_detected_version = rocm_path.split("-")[-1] if "-" in rocm_path else "7.0"
34+
2835

2936
def load_version_module(pkg_path):
30-
spec = importlib.util.spec_from_file_location(
31-
'version', os.path.join(pkg_path, 'version.py'))
32-
module = importlib.util.module_from_spec(spec)
33-
spec.loader.exec_module(module)
34-
return module
37+
"""Import and return the package's version helper module dynamically."""
38+
spec = importlib.util.spec_from_file_location(
39+
"version", os.path.join(pkg_path, "version.py")
40+
)
41+
module = importlib.util.module_from_spec(spec)
42+
spec.loader.exec_module(module) # type: ignore[attr-defined]
43+
return module
44+
3545

3646
_version_module = load_version_module(f"jax_plugins/xla_rocm{rocm_version}")
3747
__version__ = _version_module._get_version_for_build()
3848

39-
packages = find_namespace_packages(
40-
include=[
41-
package_name,
42-
f"{package_name}.*",
43-
]
44-
)
49+
packages = find_namespace_packages(include=[package_name, f"{package_name}.*"])
4550

4651
setup(
4752
name=project_name,
4853
version=__version__,
4954
description=f"JAX XLA PJRT Plugin for AMD GPUs (ROCm:{rocm_detected_version})",
5055
long_description="",
5156
long_description_content_type="text/markdown",
52-
author="Ruturaj4",
53-
author_email="Ruturaj.Vaidya@amd.com",
57+
author="GulsumGA-AMD",
58+
author_email="Gulsum.GudukbayAkbulut@amd.com",
5459
packages=packages,
5560
install_requires=[],
56-
url="https://github.com/jax-ml/jax",
61+
url="https://github.com/ROCm/rocm-jax",
5762
license="Apache-2.0",
5863
classifiers=[
59-
"Development Status :: 3 - Alpha",
64+
"Development Status :: 5 - Production/Stable",
6065
"Programming Language :: Python :: 3",
6166
],
6267
package_data={

0 commit comments

Comments
 (0)