@@ -114,6 +114,7 @@ def dist_wheels(
114114 rocm_version ,
115115 python_versions ,
116116 xla_source_dir ,
117+ jax_source_dir ,
117118 rocm_build_job = "" ,
118119 rocm_build_num = "" ,
119120 therock_path = None ,
@@ -137,37 +138,17 @@ def dist_wheels(
137138 xla_path = os .path .realpath (os .path .expanduser (xla_source_dir ))
138139 cmd .append ("--xla-source-dir=%s" % xla_path )
139140
141+ if jax_source_dir :
142+ jax_path = os .path .realpath (os .path .expanduser (jax_source_dir ))
143+ cmd .append ("--jax-source-dir=%s" % jax_path )
144+
140145 if rbe :
141146 cmd .append ("--rbe" )
142147
143148 cmd .append ("dist_wheels" )
144149 subprocess .check_call (cmd , cwd = jax_plugin_dir )
145150
146151
147- def _fetch_jax_metadata (xla_path ):
148- cmd = ["git" , "rev-parse" , "HEAD" ]
149- jax_commit = subprocess .check_output (cmd )
150- xla_commit = b""
151-
152- if xla_path :
153- try :
154- xla_commit = subprocess .check_output (cmd , cwd = xla_path )
155- except Exception as ex :
156- LOG .warning ("Exception while retrieving xla_commit: %s" % ex )
157-
158- cmd = ["python3" , "setup.py" , "-V" ]
159- env = dict (os .environ )
160- env ["JAX_RELEASE" ] = "1"
161-
162- jax_version = subprocess .check_output (cmd , env = env )
163-
164- return {
165- "jax_version" : jax_version .decode ("utf8" ).strip (),
166- "jax_commit" : jax_commit .decode ("utf8" ).strip (),
167- "xla_commit" : xla_commit .decode ("utf8" ).strip (),
168- }
169-
170-
171152def _apply_filters (docker_filters , dockerfile_basename , docker_dir = "docker" ):
172153 """
173154 Collect Dockerfile paths matching a basename prefix and optional substring filters.
@@ -540,7 +521,12 @@ def parse_args():
540521
541522 p .add_argument (
542523 "--xla-source-dir" ,
543- help = "Path to XLA source to use during jaxlib build, instead of builtin XLA" ,
524+ help = "Path to XLA source to use during plugin and jaxlib build, instead of builtin XLA" ,
525+ )
526+
527+ p .add_argument (
528+ "--jax-source-dir" ,
529+ help = "Optional JAX source directory. When provided, builds jaxlib wheel and copies to wheelhouse." ,
544530 )
545531
546532 p .add_argument (
@@ -623,6 +609,7 @@ def main():
623609 rocm_version = args .rocm_version ,
624610 python_versions = args .python_versions ,
625611 xla_source_dir = args .xla_source_dir ,
612+ jax_source_dir = args .jax_source_dir ,
626613 rocm_build_job = args .rocm_build_job ,
627614 rocm_build_num = args .rocm_build_num ,
628615 therock_path = args .therock_path ,
0 commit comments