1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ # pylint: disable=protected-access
15+ """Setup script for the ROCm JAX runtime Python convenience wheel.
16+
17+ Minimal file whose only responsibility is to expose the correct version and
18+ package metadata for the ROCm plugin. The real build logic lives in the
19+ version module that is generated alongside the package.
20+ """
1421
1522import importlib
1623import os
1724from setuptools import setup
1825from setuptools .dist import Distribution
1926
2027__version__ = None
21- rocm_version = 0 # placeholder
22- project_name = f"jax-rocm{ rocm_version } -plugin"
23- package_name = f"jax_rocm{ rocm_version } _plugin"
28+ rocm_version = 0 # placeholder # pylint: disable=invalid-name
29+ project_name = f"jax-rocm{ rocm_version } -plugin" # pylint: disable=invalid-name
30+ package_name = f"jax_rocm{ rocm_version } _plugin" # pylint: disable=invalid-name
2431
2532# Extract ROCm version from the `ROCM_PATH` environment variable.
26- default_rocm_path = "/opt/rocm"
27- rocm_path = os .getenv ("ROCM_PATH" , default_rocm_path )
28- rocm_detected_version = rocm_path .split ('-' )[- 1 ] if '-' in rocm_path else "unknown"
33+ DEFAULT_ROCM_PATH = "/opt/rocm"
34+ rocm_path = os .getenv ("ROCM_PATH" , DEFAULT_ROCM_PATH )
35+ rocm_detected_version = rocm_path .split ("-" )[- 1 ] if "-" in rocm_path else "7.0"
36+
2937
3038def load_version_module (pkg_path ):
31- spec = importlib .util .spec_from_file_location (
32- 'version' , os .path .join (pkg_path , 'version.py' ))
33- module = importlib .util .module_from_spec (spec )
34- spec .loader .exec_module (module )
35- return module
39+ """Dynamically import and return the version helper module for the package."""
40+ spec = importlib .util .spec_from_file_location (
41+ "version" , os .path .join (pkg_path , "version.py" )
42+ )
43+ module = importlib .util .module_from_spec (spec )
44+ spec .loader .exec_module (module ) # type: ignore[attr-defined]
45+ return module
46+
3647
3748_version_module = load_version_module (package_name )
38- __version__ = _version_module ._get_version_for_build ()
39- _cmdclass = _version_module ._get_cmdclass (package_name )
49+ __version__ = (
50+ _version_module ._get_version_for_build ()
51+ ) # protected helper from generated module
52+ _cmdclass = _version_module ._get_cmdclass (
53+ package_name
54+ ) # protected helper from generated module
55+
4056
4157class BinaryDistribution (Distribution ):
42- """This class makes 'bdist_wheel' include an ABI tag on the wheel."""
58+ """This class makes 'bdist_wheel' include an ABI tag on the wheel."""
59+
60+ def has_ext_modules (self ): # type: ignore[override]
61+ """Return True to force wheel build to include an ABI tag."""
62+ return True
4363
44- def has_ext_modules (self ):
45- return True
4664
4765setup (
4866 name = project_name ,
@@ -51,18 +69,16 @@ def has_ext_modules(self):
5169 description = f"JAX Plugin for AMD GPUs (ROCm:{ rocm_detected_version } )" ,
5270 long_description = "" ,
5371 long_description_content_type = "text/markdown" ,
54- author = "Ruturaj4 " ,
55- author_email = "Ruturaj.Vaidya @amd.com" ,
72+ author = "GulsumGA-AMD " ,
73+ author_email = "Gulsum.GudukbayAkbulut @amd.com" ,
5674 packages = [package_name ],
57- python_requires = ">=3.9 " ,
75+ python_requires = ">=3.10 " ,
5876 install_requires = [f"jax-rocm{ rocm_version } -pjrt=={ __version__ } " ],
59- url = "https://github.com/jax-ml/ jax" ,
77+ url = "https://github.com/ROCm/rocm- jax" ,
6078 license = "Apache-2.0" ,
6179 classifiers = [
62- "Development Status :: 3 - Alpha" ,
63- "Programming Language :: Python :: 3.9" ,
80+ "Development Status :: 5 - Production/Stable" ,
6481 "Programming Language :: Python :: 3.10" ,
65- "Programming Language :: Python :: 3.11" ,
6682 "Programming Language :: Python :: 3.12" ,
6783 ],
6884 package_data = {
0 commit comments