### Description
While trying to construct a minimal working example (MWE) for a crash that combines `shard_map` with explicit-sharding meshes, I stumbled upon a hard crash: running the script below aborts the process with the stack trace reported further down. I would not be surprised if this use case is currently unsupported, but I think it would be useful to support.
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, AxisType
from jax.experimental.shard_map import shard_map
from jax.experimental.shard import reshard

# Setup: 2 CPU devices
jax.config.update("jax_num_cpu_devices", 2)
devices = np.array(jax.devices())
mesh = jax.make_mesh((len(jax.devices()),), ("i",), axis_types=(AxisType.Explicit,))
jax.sharding.set_mesh(mesh)  # Set this as the default mesh for jax.

# Define a simple function: takes w, x -> outputs a batch of scalars
def simple_func(w, x):
    # w: (3,), x: (batch, 3)
    return jnp.sum(w * x, axis=-1)

# Make inputs
w = jnp.array([1.0, 2.0, 3.0])  # weights, shape (3,)
x = jnp.ones((4, 3))            # batch of 4 vectors, shape (4, 3)

# Set up sharding
replicated_w = reshard(w, P())        # replicated
sharded_x = reshard(x, P("i", None))  # sharded along "i"

# --- Evaluate normally ---
out = simple_func(replicated_w, sharded_x)
print("Simple call works:", out)

# --- Try with shard_map ---
out = shard_map(
    simple_func,
    mesh,
    in_specs=(P(None), P("i", None)),  # shard x on axis 0
    out_specs=P("i"),
)(w, x)
print("Shard map works:", out)

try:
    # Define shard_map over x only, capturing w in the closure
    shard_out = shard_map(
        lambda xi: simple_func(w, xi),
        mesh,
        in_specs=(P("i", None),),  # shard x on axis 0
        out_specs=P("i"),
    )(x)
    print("shard_map call works:", shard_out)
except Exception as e:
    print("shard_map call FAILED:", e)
```
The stack trace is:

```
.venv ❯ ipython pp.py
Simple call works: [6. 6. 6. 6.]
Shard map works: [6. 6. 6. 6.]
F0602 12:25:11.743891 7041633 hlo_sharding.cc:1004] Check failed: !IsManual()
*** Check failure stack trace: ***
    @        0x15beda88c  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @        0x15beda720  absl::lts_20230802::log_internal::LogMessage::Flush()
    @        0x15bedab60  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @        0x15bedab78  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @        0x15bd15720  xla::HloSharding::NumTiles()
    @        0x15740a384  xla::spmd::PartitionedHlo::Reshard()
    @        0x157426060  xla::spmd::SpmdPartitioningVisitor::HandleElementwise()
    @        0x15bbbed00  xla::PostOrderDFS<>()
    @        0x15bbbcd74  xla::HloInstruction::Accept<>()
    @        0x155ecec80  xla::HloComputation::Accept<>()
    @        0x1574403f8  xla::spmd::SpmdPartitioningVisitor::DoPartition()
    @        0x157424c4c  xla::spmd::SpmdPartitioner::PartitionComputation()
    @        0x157443774  xla::spmd::SpmdPartitioner::Run()
    @        0x157479bac  xla::HloPassPipeline::RunHelper()
    @        0x157476fd4  xla::HloPassPipeline::RunPassesInternal<>()
    @        0x1574769a4  xla::HloPassPipeline::Run()
    @        0x156746c14  xla::cpu::CpuCompiler::RunHloPassesThroughLayoutAssn()
    @        0x15674a67c  xla::cpu::CpuCompiler::RunHloPasses()
    @        0x15674a848  xla::cpu::CpuCompiler::RunHloPasses()
    @        0x1566ee9bc  xla::TfrtCpuClient::CompileInternal()
    @        0x1566ed330  xla::TfrtCpuClient::CompileAndLoad()
    @        0x15a8e9590  xla::ifrt::PjRtLoadedExecutable::Create()
    @        0x15a8e53d0  xla::ifrt::PjRtCompiler::CompileAndLoad()
    @        0x156539ae0  xla::PyClient::CompileAndLoadIfrtProgram()
    @        0x15653a57c  xla::PyClient::CompileAndLoad()
    @        0x1565428dc  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @        0x15a8bf6c4  nanobind::detail::nb_func_vectorcall_complex()
    @        0x107e2d834  nanobind::detail::nb_bound_method_vectorcall()
    @        0x103f4b2b8  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x15a8b7bd4  nanobind::detail::obj_vectorcall()
    @        0x15bf12520  nanobind::detail::api<>::operator()<>()
    @        0x15bf11af0  jax::WeakrefLRUCache::Call()
    @        0x15bf14228  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @        0x15a8bf6c4  nanobind::detail::nb_func_vectorcall_complex()
    @        0x103faea6c  _PyObject_Call_Prepend
    @        0x103fae510  slot_tp_call
    @        0x103f4cf40  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x1564b1794  jax::(anonymous namespace)::PjitFunction::Call()
    @        0x1564b04b4  PjitFunction_tp_vectorcall
    @        0x103f48850  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x1564b1794  jax::(anonymous namespace)::PjitFunction::Call()
    @        0x1564b04b4  PjitFunction_tp_vectorcall
    @        0x103f48850  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x1564b1280  jax::(anonymous namespace)::PjitFunction::Call()
    @        0x1564b04b4  PjitFunction_tp_vectorcall
    @        0x103f48850  _PyEval_EvalFrameDefault
    @        0x103faea6c  _PyObject_Call_Prepend
    @        0x103fae510  slot_tp_call
    @        0x103f489b4  _PyEval_EvalFrameDefault
zsh: abort      ipython pp.py
```
### System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.13.3 (main, Apr  9 2025, 03:47:57) [Clang 20.1.0 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:49 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6000', machine='arm64')
```