
[sharding-in-types+shard_map] Support closing over explicit meshes when using shard map #29162

Open
@PhilipVinc

Description

While trying to construct an MWE for a crash when combining shard_map with explicit sharding meshes, I stumbled upon a hard crash (a fatal CHECK failure inside XLA).

Running the script below crashes with the stack trace reported further down.
I would not be surprised if this use case is currently unsupported, but I think it would be useful to support.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, AxisType
from jax.experimental.shard_map import shard_map
from jax.experimental.shard import reshard

# Setup: 2 CPU devices
jax.config.update("jax_num_cpu_devices", 2)
devices = np.array(jax.devices())
mesh = jax.make_mesh((len(jax.devices()),), ("i",), axis_types=(AxisType.Explicit,))
jax.sharding.set_mesh(mesh) # Set this as the default mesh for jax.

# Define a simple function: takes w, x -> outputs batch of scalars
def simple_func(w, x):
    # w: (3,), x: (batch, 3)
    return jnp.sum(w * x, axis=-1)

# Make inputs
w = jnp.array([1.0, 2.0, 3.0])  # weights, size (3,)
x = jnp.ones((4, 3))            # batch of 4 vectors, shape (4, 3)

# Setup sharding
replicated_w = reshard(w, P())        # replicated across all devices
sharded_x = reshard(x, P("i", None))  # sharded along axis 0

# --- Evaluate normally ---
out = simple_func(replicated_w, sharded_x)
print("Simple call works:", out)

# --- Try with shard_map ---
out = shard_map(
        simple_func,
        mesh,
        in_specs=(P(None), P("i", None),),  # shard x on axis 0
        out_specs=P("i")
    )(w, x)
print("Shard map works:", out)

try:
    # Define shard_map over x, capturing w
    shard_out = shard_map(
        lambda xi: simple_func(w, xi),
        mesh,
        in_specs=(P("i", None),),  # shard x on axis 0
        out_specs=P("i")
    )(x)
    print("shard_map call works:", shard_out)
except Exception as e:
    print("shard_map call FAILED:", e)

The stack trace is:

(.venv) ipython pp.py
Simple call works: [6. 6. 6. 6.]
Shard map works: [6. 6. 6. 6.]
F0602 12:25:11.743891 7041633 hlo_sharding.cc:1004] Check failed: !IsManual() 
*** Check failure stack trace: ***
    @        0x15beda88c  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @        0x15beda720  absl::lts_20230802::log_internal::LogMessage::Flush()
    @        0x15bedab60  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @        0x15bedab78  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @        0x15bd15720  xla::HloSharding::NumTiles()
    @        0x15740a384  xla::spmd::PartitionedHlo::Reshard()
    @        0x157426060  xla::spmd::SpmdPartitioningVisitor::HandleElementwise()
    @        0x15bbbed00  xla::PostOrderDFS<>()
    @        0x15bbbcd74  xla::HloInstruction::Accept<>()
    @        0x155ecec80  xla::HloComputation::Accept<>()
    @        0x1574403f8  xla::spmd::SpmdPartitioningVisitor::DoPartition()
    @        0x157424c4c  xla::spmd::SpmdPartitioner::PartitionComputation()
    @        0x157443774  xla::spmd::SpmdPartitioner::Run()
    @        0x157479bac  xla::HloPassPipeline::RunHelper()
    @        0x157476fd4  xla::HloPassPipeline::RunPassesInternal<>()
    @        0x1574769a4  xla::HloPassPipeline::Run()
    @        0x156746c14  xla::cpu::CpuCompiler::RunHloPassesThroughLayoutAssn()
    @        0x15674a67c  xla::cpu::CpuCompiler::RunHloPasses()
    @        0x15674a848  xla::cpu::CpuCompiler::RunHloPasses()
    @        0x1566ee9bc  xla::TfrtCpuClient::CompileInternal()
    @        0x1566ed330  xla::TfrtCpuClient::CompileAndLoad()
    @        0x15a8e9590  xla::ifrt::PjRtLoadedExecutable::Create()
    @        0x15a8e53d0  xla::ifrt::PjRtCompiler::CompileAndLoad()
    @        0x156539ae0  xla::PyClient::CompileAndLoadIfrtProgram()
    @        0x15653a57c  xla::PyClient::CompileAndLoad()
    @        0x1565428dc  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @        0x15a8bf6c4  nanobind::detail::nb_func_vectorcall_complex()
    @        0x107e2d834  nanobind::detail::nb_bound_method_vectorcall()
    @        0x103f4b2b8  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x15a8b7bd4  nanobind::detail::obj_vectorcall()
    @        0x15bf12520  nanobind::detail::api<>::operator()<>()
    @        0x15bf11af0  jax::WeakrefLRUCache::Call()
    @        0x15bf14228  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @        0x15a8bf6c4  nanobind::detail::nb_func_vectorcall_complex()
    @        0x103faea6c  _PyObject_Call_Prepend
    @        0x103fae510  slot_tp_call
    @        0x103f4cf40  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x1564b1794  jax::(anonymous namespace)::PjitFunction::Call()
    @        0x1564b04b4  PjitFunction_tp_vectorcall
    @        0x103f48850  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x1564b1794  jax::(anonymous namespace)::PjitFunction::Call()
    @        0x1564b04b4  PjitFunction_tp_vectorcall
    @        0x103f48850  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x1041b6004  method_vectorcall.llvm.5380863741279050681
    @        0x103f4ab9c  _PyEval_EvalFrameDefault
    @        0x10401b664  PyObject_Vectorcall
    @        0x1564b1280  jax::(anonymous namespace)::PjitFunction::Call()
    @        0x1564b04b4  PjitFunction_tp_vectorcall
    @        0x103f48850  _PyEval_EvalFrameDefault
    @        0x103faea6c  _PyObject_Call_Prepend
    @        0x103fae510  slot_tp_call
    @        0x103f489b4  _PyEval_EvalFrameDefault
zsh: abort      ipython pp.py

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.13.3 (main, Apr  9 2025, 03:47:57) [Clang 20.1.0 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='mbp-10841385', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:49 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6000', machine='arm64')

Metadata

Labels: bug (Something isn't working)