Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/ucp/api/device/ucp_device_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ typedef struct ucp_device_request {
uct_device_completion_t comp;
ucs_status_t status;
uct_device_ep_h device_ep;
unsigned channel_id;
} ucp_device_request_t;


Expand Down Expand Up @@ -162,7 +161,7 @@ UCS_F_DEVICE ucs_status_t ucp_device_put_single(

return UCP_DEVICE_SEND_BLOCKING(level, uct_device_ep_put_single, device_ep,
req, uct_elem, address, remote_address,
length, flags, comp);
length, channel_id, flags, comp);
}


Expand Down Expand Up @@ -219,7 +218,7 @@ UCS_F_DEVICE ucs_status_t ucp_device_counter_inc(

return UCP_DEVICE_SEND_BLOCKING(level, uct_device_ep_atomic_add, device_ep,
req, uct_elem, inc_value, remote_address,
flags, comp);
channel_id, flags, comp);
}


Expand Down Expand Up @@ -286,7 +285,7 @@ UCS_F_DEVICE ucs_status_t ucp_device_put_multi(
mem_list_h->mem_list_length, addresses,
remote_addresses, lengths,
counter_inc_value, counter_remote_address,
flags, comp);
channel_id, flags, comp);
}


Expand Down Expand Up @@ -373,7 +372,7 @@ UCS_F_DEVICE ucs_status_t ucp_device_put_multi_partial(
remote_addresses, local_offsets,
remote_offsets, lengths, counter_index,
counter_inc_value, counter_remote_address,
flags, comp);
channel_id, flags, comp);
}


Expand Down
23 changes: 14 additions & 9 deletions src/uct/api/device/uct_device_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ union uct_device_completion {
* @param [in] address Local virtual address to send data from.
* @param [in] remote_address Remote virtual address to write data to.
* @param [in] length Length in bytes of the data to send.
* @param [in] channel_id Channel ID to use for the transfer.
* @param [in] flags Flags to modify the function behavior.
* @param [in] comp Completion object to track the progress of operation.
*
Expand All @@ -53,12 +54,13 @@ template<ucs_device_level_t level>
UCS_F_DEVICE ucs_status_t uct_device_ep_put_single(
uct_device_ep_h device_ep, const uct_device_mem_element_t *mem_elem,
const void *address, uint64_t remote_address, size_t length,
uint64_t flags, uct_device_completion_t *comp)
unsigned channel_id, uint64_t flags, uct_device_completion_t *comp)
{
if (device_ep->uct_tl_id == UCT_DEVICE_TL_RC_MLX5_GDA) {
return uct_rc_mlx5_gda_ep_put_single<level>(device_ep, mem_elem,
address, remote_address,
length, flags, comp);
length, channel_id, flags,
comp);
} else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
return uct_cuda_ipc_ep_put_single<level>(device_ep, mem_elem, address,
remote_address, length, flags,
Expand All @@ -85,6 +87,7 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_put_single(
* @param [in] mem_elem Memory element representing the memory to be modified.
* @param [in] inc_value Value of the remote increment.
* @param [in] remote_address Remote virtual address to write data to.
* @param [in] channel_id Channel ID to use for the transfer.
* @param [in] flags Flags to modify the function behavior.
* @param [in] comp Completion object to track the progress of operation.
*
Expand All @@ -98,13 +101,13 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_put_single(
template<ucs_device_level_t level>
UCS_F_DEVICE ucs_status_t uct_device_ep_atomic_add(
uct_device_ep_h device_ep, const uct_device_mem_element_t *mem_elem,
uint64_t inc_value, uint64_t remote_address, uint64_t flags,
uct_device_completion_t *comp)
uint64_t inc_value, uint64_t remote_address, unsigned channel_id,
uint64_t flags, uct_device_completion_t *comp)
{
if (device_ep->uct_tl_id == UCT_DEVICE_TL_RC_MLX5_GDA) {
return uct_rc_mlx5_gda_ep_atomic_add<level>(device_ep, mem_elem,
inc_value, remote_address,
flags, comp);
channel_id, flags, comp);
} else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
return uct_cuda_ipc_ep_atomic_add<level>(device_ep, mem_elem, inc_value,
remote_address, flags, comp);
Expand Down Expand Up @@ -146,6 +149,7 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_atomic_add(
* @param [in] lengths Array of lengths in bytes for each send.
* @param [in] counter_inc_value Value of the remote increment.
* @param [in] counter_remote_address Remote address to increment to.
* @param [in] channel_id Channel ID to use for the transfer.
* @param [in] flags Flags to modify the function behavior.
* @param [out] req Request populated by the call.
*
Expand All @@ -162,15 +166,15 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_put_multi(
unsigned mem_list_count, void *const *addresses,
const uint64_t *remote_addresses, const size_t *lengths,
uint64_t counter_inc_value, uint64_t counter_remote_address,
uint64_t flags, uct_device_completion_t *comp)
unsigned channel_id, uint64_t flags, uct_device_completion_t *comp)
{
if (device_ep->uct_tl_id == UCT_DEVICE_TL_RC_MLX5_GDA) {
return uct_rc_mlx5_gda_ep_put_multi<level>(device_ep, mem_list,
mem_list_count, addresses,
remote_addresses, lengths,
counter_inc_value,
counter_remote_address,
flags, comp);
channel_id, flags, comp);
} else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
return uct_cuda_ipc_ep_put_multi<level>(device_ep, mem_list,
mem_list_count, addresses,
Expand Down Expand Up @@ -227,6 +231,7 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_put_multi(
* @param [in] counter_index Index of remote increment descriptor.
* @param [in] counter_inc_value Value of the remote increment.
* @param [in] counter_remote_address Remote address to increment to.
* @param [in] channel_id Channel ID to use for the transfer.
* @param [in] flags Flags to modify the function behavior.
* @param [in] comp Completion object to track progress.
*
Expand All @@ -245,14 +250,14 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_put_multi_partial(
const size_t *local_offsets, const size_t *remote_offsets,
const size_t *lengths, unsigned counter_index,
uint64_t counter_inc_value, uint64_t counter_remote_address,
uint64_t flags, uct_device_completion_t *comp)
unsigned channel_id, uint64_t flags, uct_device_completion_t *comp)
{
if (device_ep->uct_tl_id == UCT_DEVICE_TL_RC_MLX5_GDA) {
return uct_rc_mlx5_gda_ep_put_multi_partial<level>(
device_ep, mem_list, mem_list_indices, mem_list_count,
addresses, remote_addresses, local_offsets, remote_offsets,
lengths, counter_index, counter_inc_value,
counter_remote_address, flags, comp);
counter_remote_address, channel_id, flags, comp);
} else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
return uct_cuda_ipc_ep_put_multi_partial<level>(
device_ep, mem_list, mem_list_indices, mem_list_count,
Expand Down
Loading
Loading