scatter fails with dynamic slicing on pinned_host #34565

@mathieu-reymond

Description

Hi,

I am encountering an issue when updating an array in pinned memory. I believe it is related to XLA; please feel free to correct me if this is not the right place to ask.

The following small script creates an array x in pinned host memory and an array y on GPU. When updating x with y using a fixed slice (x = x.at[:5].set(y)), everything works fine:

import jax
import jax.numpy as jnp
import numpy as np

gpu = jax.sharding.SingleDeviceSharding(jax.devices('gpu')[0], memory_kind='device')
cpu = jax.sharding.SingleDeviceSharding(jax.devices('gpu')[0], memory_kind='pinned_host')

shape = (int(2 * 1024**3 / 4),)
x = np.zeros(shape, dtype=jnp.float32)
x_cpu = jax.device_put(x, cpu)
y_gpu = jax.device_put(jnp.arange(5), gpu)

def update(x, y):
    y = y.astype(x.dtype)
    y = jax.device_put(y, cpu)
    x = x.at[:5].set(y)
    return x

update_jit = jax.jit(update, donate_argnums=(0,))
update_jit(x_cpu, y_gpu)

However, if I change the update function as follows:

def update(x, y):
    i = jnp.arange(5)
    i = jax.device_put(i, cpu)
    y = y.astype(x.dtype)
    y = jax.device_put(y, cpu)
    x = x.at[i].set(y)
    return x

I get the following error trace:

W1106 22:40:22.641304   74111 host_offloader.cc:294] Found an instruction ("scatter.5") which does device compute in host memory space. Converting into host compute. This is likely to have a very slow execution time. If you're using JAX, use device_put() to move the inputs to the device so that computation happens on the device.
W1106 22:40:22.641353   74111 host_offloader.cc:294] Found an instruction ("lt.1") which does device compute in host memory space. Converting into host compute. This is likely to have a very slow execution time. If you're using JAX, use device_put() to move the inputs to the device so that computation happens on the device.
W1106 22:40:22.641367   74111 host_offloader.cc:294] Found an instruction ("select_n.1") which does device compute in host memory space. Converting into host compute. This is likely to have a very slow execution time. If you're using JAX, use device_put() to move the inputs to the device so that computation happens on the device.
F1106 22:40:22.641378   74111 host_offload_utils.cc:225] Check failed: instruction->operand_count() == 1 Expecting instruction lt.1 to have 1 operand, but it has 2.
[symbolize_elf.inc : 379] RAW: Unable to get high fd: rc=0, limit=1024
*** Check failure stack trace: ***
    @     0x7f588bc2d8c4  absl::lts_20250814::log_internal::LogMessage::SendToLog()
    @     0x7f588bc2d846  absl::lts_20250814::log_internal::LogMessage::Flush()
    @     0x7f588b5ca22f  xla::host_offload_utils::GetPredecessors()
    @     0x7f58818287c2  xla::HostOffloader::WalkDownHostMemoryOffloadPaths()
    @     0x7f588182b8ac  xla::HostOffloader::HandleMoveToHostCustomCall()
    @     0x7f5881830798  xla::HostOffloader::ProcessNextMoveToHostInstr()
    @     0x7f588183189a  xla::HostOffloader::Run()
    @     0x7f5882d47fed  xla::HloPassPipeline::RunHelper<>()
    @     0x7f5882d450f6  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x7f5882d44a1e  xla::HloPassPipeline::Run()
    @     0x7f58816ebf41  xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
    @     0x7f58816d2b64  xla::gpu::NVPTXCompiler::OptimizeHloPostLayoutAssignment()
    @     0x7f58816e51fb  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x7f58816eddd7  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x7f58816bf0e8  xla::Service::BuildExecutable()
    @     0x7f588169e4ae  xla::LocalService::CompileExecutables()
    @     0x7f588169b774  xla::LocalClient::Compile()
    @     0x7f58815ff045  xla::PjRtStreamExecutorClient::CompileInternal()
    @     0x7f58816003f0  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f5881600f15  xla::PjRtStreamExecutorClient::CompileAndLoad()
    @     0x7f58815c2d0d  xla::StreamExecutorGpuClient::CompileAndLoad()
    @     0x7f58815b08d7  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x7f588159e723  pjrt::PJRT_Client_Compile()
    @     0x7f5951eb9daf  xla::InitializeArgsAndCompile()
    @     0x7f5951eba228  xla::PjRtCApiClient::CompileAndLoad()
    @     0x7f594e8b7fe1  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7f594e8b2cb4  xla::ifrt::PjRtCompiler::CompileAndLoad()
    @     0x7f594e76cf2e  jax::PyClient::CompileAndLoadIfrtProgram()
    @     0x7f594e76ec1d  jax::PyClient::CompileAndLoad()
    @     0x7f594e77b4d9  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f594e841c81  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f594e84253c  nanobind::detail::nb_bound_method_vectorcall()
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f595a382dc8  _PyEval_EvalFrameDefault
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f594e838d9d  nanobind::detail::obj_vectorcall()
    @     0x7f59585c6eb3  nanobind::detail::api<>::operator()<>()
    @     0x7f59585c5a6f  jax::WeakrefLRUCache::Call()
    @     0x7f59585c90c3  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f594e841c81  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f595a1eeae4  _PyObject_FastCallDictTstate
    @     0x7f595a1eec76  _PyObject_Call_Prepend
    @     0x7f595a285a8d  (unknown)
    @     0x7f595a1ee960  _PyObject_MakeTpCall
    @     0x7f595a382dc8  _PyEval_EvalFrameDefault
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f5951e42b13  jax::(anonymous namespace)::PjitFunction::Call()
    @     0x7f5951e4101f  PjitFunction_tp_vectorcall
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f595a382dc8  _PyEval_EvalFrameDefault
    @     0x7f595a38d441  PyEval_EvalCode
    @     0x7f595a3911cf  (unknown)
    @     0x7f595a3912fa  (unknown)
    @     0x7f595a391429  (unknown)
    @     0x7f595a3a79ed  _PyRun_SimpleFileObject
    @     0x7f595a3a85d9  _PyRun_AnyFileObject
    @     0x7f595a3adcef  Py_RunMain
    @     0x7f5959c27675  (unknown)
    @     0x7f5959c27729  __libc_start_main
    @     0x557317d09055  _start
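
For what it's worth, the crash seems tied to the scatter being placed in host memory: if I drop the memory-kind placements entirely and keep the dynamic indices on the default device, the same dynamic-index update compiles and runs fine. A minimal sketch (shrunk, and without the pinned_host placements so it runs on any backend):

```python
import jax
import jax.numpy as jnp
import numpy as np

def update(x, y):
    i = jnp.arange(5)        # dynamic indices, but left on the default device
    y = y.astype(x.dtype)
    return x.at[i].set(y)    # scatter lowers to regular device compute

x = jnp.zeros(1024, dtype=jnp.float32)
y = jnp.arange(5)
out = jax.jit(update)(x, y)
print(np.asarray(out[:6]))  # [0. 1. 2. 3. 4. 0.]
```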

When I look at the jaxprs of both variants, I don't see any obvious difference except for the scatter operation. Here is the working variant:

{ lambda ; a:f32<host>[536870912] b:i32[5]. let
    c:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] b
    d:f32<host>[5] = device_put[
      copy_semantics=(ArrayCopySemantics.REUSE_INPUT,)
      devices=(SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=pinned_host),)
      srcs=(None,)
    ] c
    e:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 0:i32[]
    f:f32<host>[536870912] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
      indices_are_sorted=True
      mode=GatherScatterMode.FILL_OR_DROP
      unique_indices=True
      update_consts=()
      update_jaxpr=None
    ] a e d
  in (f,) }

and here is the second one (which results in the error trace):

{ lambda ; a:f32<host>[536870912] b:i32[5]. let
    c:i32[5] = iota[dimension=0 dtype=int32 shape=(5,) sharding=None] 
    d:i32<host>[5] = device_put[
      copy_semantics=(ArrayCopySemantics.REUSE_INPUT,)
      devices=(SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=pinned_host),)
      srcs=(None,)
    ] c
    e:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] b
    f:f32<host>[5] = device_put[
      copy_semantics=(ArrayCopySemantics.REUSE_INPUT,)
      devices=(SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=pinned_host),)
      srcs=(None,)
    ] e
    g:bool<host>[5] = lt d 0:i32[]
    h:i32<host>[5] = add d 536870912:i32[]
    i:i32<host>[5] = select_n g d h
    j:i32<host>[5,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(5, 1)
      sharding=None
    ] i
    k:f32<host>[536870912] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
      indices_are_sorted=False
      mode=GatherScatterMode.FILL_OR_DROP
      unique_indices=False
      update_consts=()
      update_jaxpr=None
    ] a j f
  in (k,) }

I'm afraid I don't fully understand what the error means, so any help would be greatly appreciated.
