Hi,

I encountered an issue when updating an array in pinned memory. I believe it is related to XLA; please feel free to correct me if this is not the right place to ask.

The following small script creates an array `x` in pinned host memory and an array `y` on GPU. When updating `x` with `y` using a static slice (`x = x.at[:5].set(y)`), everything works fine:
```python
import jax
import jax.numpy as jnp
import numpy as np

# Both shardings live on the GPU device; they differ only in memory kind.
gpu = jax.sharding.SingleDeviceSharding(jax.devices('gpu')[0], memory_kind='device')
cpu = jax.sharding.SingleDeviceSharding(jax.devices('gpu')[0], memory_kind='pinned_host')

shape = (int(2 * 1024**3 / 4),)  # 2 GiB of float32
x = np.zeros(shape, dtype=jnp.float32)
x_cpu = jax.device_put(x, cpu)
y_gpu = jax.device_put(jnp.arange(5), gpu)

def update(x, y):
    y = y.astype(x.dtype)
    y = jax.device_put(y, cpu)
    x = x.at[:5].set(y)
    return x

update_jit = jax.jit(update, donate_argnums=(0,))
update_jit(x_cpu, y_gpu)
```

However, if I change the update function as follows:
```python
def update(x, y):
    i = jnp.arange(5)
    i = jax.device_put(i, cpu)
    y = y.astype(x.dtype)
    y = jax.device_put(y, cpu)
    x = x.at[i].set(y)
    return x
```

I get the following error trace:
```
W1106 22:40:22.641304 74111 host_offloader.cc:294] Found an instruction ("scatter.5") which does device compute in host memory space. Converting into host compute. This is likely to have a very slow execution time. If you're using JAX, use device_put() to move the inputs to the device so that computation happens on the device.
W1106 22:40:22.641353 74111 host_offloader.cc:294] Found an instruction ("lt.1") which does device compute in host memory space. Converting into host compute. This is likely to have a very slow execution time. If you're using JAX, use device_put() to move the inputs to the device so that computation happens on the device.
W1106 22:40:22.641367 74111 host_offloader.cc:294] Found an instruction ("select_n.1") which does device compute in host memory space. Converting into host compute. This is likely to have a very slow execution time. If you're using JAX, use device_put() to move the inputs to the device so that computation happens on the device.
F1106 22:40:22.641378 74111 host_offload_utils.cc:225] Check failed: instruction->operand_count() == 1 Expecting instruction lt.1 to have 1 operand, but it has 2.
[symbolize_elf.inc : 379] RAW: Unable to get high fd: rc=0, limit=1024
*** Check failure stack trace: ***
    @     0x7f588bc2d8c4  absl::lts_20250814::log_internal::LogMessage::SendToLog()
    @     0x7f588bc2d846  absl::lts_20250814::log_internal::LogMessage::Flush()
    @     0x7f588b5ca22f  xla::host_offload_utils::GetPredecessors()
    @     0x7f58818287c2  xla::HostOffloader::WalkDownHostMemoryOffloadPaths()
    @     0x7f588182b8ac  xla::HostOffloader::HandleMoveToHostCustomCall()
    @     0x7f5881830798  xla::HostOffloader::ProcessNextMoveToHostInstr()
    @     0x7f588183189a  xla::HostOffloader::Run()
    @     0x7f5882d47fed  xla::HloPassPipeline::RunHelper<>()
    @     0x7f5882d450f6  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x7f5882d44a1e  xla::HloPassPipeline::Run()
    @     0x7f58816ebf41  xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
    @     0x7f58816d2b64  xla::gpu::NVPTXCompiler::OptimizeHloPostLayoutAssignment()
    @     0x7f58816e51fb  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x7f58816eddd7  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x7f58816bf0e8  xla::Service::BuildExecutable()
    @     0x7f588169e4ae  xla::LocalService::CompileExecutables()
    @     0x7f588169b774  xla::LocalClient::Compile()
    @     0x7f58815ff045  xla::PjRtStreamExecutorClient::CompileInternal()
    @     0x7f58816003f0  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f5881600f15  xla::PjRtStreamExecutorClient::CompileAndLoad()
    @     0x7f58815c2d0d  xla::StreamExecutorGpuClient::CompileAndLoad()
    @     0x7f58815b08d7  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x7f588159e723  pjrt::PJRT_Client_Compile()
    @     0x7f5951eb9daf  xla::InitializeArgsAndCompile()
    @     0x7f5951eba228  xla::PjRtCApiClient::CompileAndLoad()
    @     0x7f594e8b7fe1  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7f594e8b2cb4  xla::ifrt::PjRtCompiler::CompileAndLoad()
    @     0x7f594e76cf2e  jax::PyClient::CompileAndLoadIfrtProgram()
    @     0x7f594e76ec1d  jax::PyClient::CompileAndLoad()
    @     0x7f594e77b4d9  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f594e841c81  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f594e84253c  nanobind::detail::nb_bound_method_vectorcall()
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f595a382dc8  _PyEval_EvalFrameDefault
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f594e838d9d  nanobind::detail::obj_vectorcall()
    @     0x7f59585c6eb3  nanobind::detail::api<>::operator()<>()
    @     0x7f59585c5a6f  jax::WeakrefLRUCache::Call()
    @     0x7f59585c90c3  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f594e841c81  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f595a1eeae4  _PyObject_FastCallDictTstate
    @     0x7f595a1eec76  _PyObject_Call_Prepend
    @     0x7f595a285a8d  (unknown)
    @     0x7f595a1ee960  _PyObject_MakeTpCall
    @     0x7f595a382dc8  _PyEval_EvalFrameDefault
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f5951e42b13  jax::(anonymous namespace)::PjitFunction::Call()
    @     0x7f5951e4101f  PjitFunction_tp_vectorcall
    @     0x7f595a1f635f  PyObject_Vectorcall
    @     0x7f595a382dc8  _PyEval_EvalFrameDefault
    @     0x7f595a38d441  PyEval_EvalCode
    @     0x7f595a3911cf  (unknown)
    @     0x7f595a3912fa  (unknown)
    @     0x7f595a391429  (unknown)
    @     0x7f595a3a79ed  _PyRun_SimpleFileObject
    @     0x7f595a3a85d9  _PyRun_AnyFileObject
    @     0x7f595a3adcef  Py_RunMain
    @     0x7f5959c27675  (unknown)
    @     0x7f5959c27729  __libc_start_main
    @     0x557317d09055  _start
```
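The shape of this difference can be reproduced on any backend with small arrays: both variants lower `.at[...].set(...)` to a `scatter`, but the integer-array version additionally emits `lt`/`add`/`select_n` ops to wrap possibly-negative indices at run time, and those are exactly the instructions the host offloader warns about above. A minimal sketch (CPU is fine here; the sizes are purely illustrative):

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(16, dtype=jnp.float32)
y = jnp.arange(5, dtype=jnp.float32)
i = jnp.arange(5)

# Static slice: a scatter with statically-known, in-bounds indices.
jaxpr_static = jax.make_jaxpr(lambda x, y: x.at[:5].set(y))(x, y)

# Integer-array indexing: a scatter preceded by lt/add/select_n, which
# normalize potentially-negative indices at run time.
jaxpr_dynamic = jax.make_jaxpr(lambda x, y, i: x.at[i].set(y))(x, y, i)

print(jaxpr_static)
print(jaxpr_dynamic)
```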
When I look at both variants' jaxprs, I don't see any obvious difference except for the scatter operation. Here is the working variant:
```
{ lambda ; a:f32<host>[536870912] b:i32[5]. let
    c:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] b
    d:f32<host>[5] = device_put[
      copy_semantics=(ArrayCopySemantics.REUSE_INPUT,)
      devices=(SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=pinned_host),)
      srcs=(None,)
    ] c
    e:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] 0:i32[]
    f:f32<host>[536870912] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
      indices_are_sorted=True
      mode=GatherScatterMode.FILL_OR_DROP
      unique_indices=True
      update_consts=()
      update_jaxpr=None
    ] a e d
  in (f,) }
```

and here is the second one (which results in the error trace):
```
{ lambda ; a:f32<host>[536870912] b:i32[5]. let
    c:i32[5] = iota[dimension=0 dtype=int32 shape=(5,) sharding=None]
    d:i32<host>[5] = device_put[
      copy_semantics=(ArrayCopySemantics.REUSE_INPUT,)
      devices=(SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=pinned_host),)
      srcs=(None,)
    ] c
    e:f32[5] = convert_element_type[new_dtype=float32 weak_type=False] b
    f:f32<host>[5] = device_put[
      copy_semantics=(ArrayCopySemantics.REUSE_INPUT,)
      devices=(SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=pinned_host),)
      srcs=(None,)
    ] e
    g:bool<host>[5] = lt d 0:i32[]
    h:i32<host>[5] = add d 536870912:i32[]
    i:i32<host>[5] = select_n g d h
    j:i32<host>[5,1] = broadcast_in_dim[
      broadcast_dimensions=(0,)
      shape=(5, 1)
      sharding=None
    ] i
    k:f32<host>[536870912] = scatter[
      dimension_numbers=ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,), operand_batching_dims=(), scatter_indices_batching_dims=())
      indices_are_sorted=False
      mode=GatherScatterMode.FILL_OR_DROP
      unique_indices=False
      update_consts=()
      update_jaxpr=None
    ] a j f
  in (k,) }
```

I'm afraid I don't fully understand what the error means, so any help would be greatly appreciated.
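For what it's worth, the only workaround I can think of (an assumption on my part, not a verified fix for the crash) is to keep the dynamic-index update entirely on device, so that no compute instruction ends up placed in host memory space, and only transfer the result to the pinned-host sharding afterwards. A sketch, written against the default device so it runs on any backend:

```python
import jax
import jax.numpy as jnp

def update_on_device(x, y, i):
    # All compute (index normalization + scatter) stays on device;
    # the caller can device_put the result to pinned_host afterwards.
    return x.at[i].set(y.astype(x.dtype))

x = jnp.zeros(16, dtype=jnp.float32)  # illustrative size, not 2 GiB
y = jnp.arange(5)
i = jnp.arange(5)
out = jax.jit(update_on_device)(x, y, i)
print(out[:6])
```

This of course gives up updating the pinned buffer in place (and the donation), so it is probably not acceptable for a 2 GiB array; I mention it only as a data point.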