Skip to content

Commit

Permalink
spml/ucx: shuffle EPs creation
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-shalev committed Nov 4, 2024
1 parent afc970c commit b734683
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions oshmem/mca/spml/ucx/spml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ static int oshmem_shmem_xchng(
if (NULL == rcv_sizes) {
goto err;
}

rc = oshmem_shmem_allgather(local_size, rcv_sizes, ucp_workers * sizeof(*rcv_sizes));
if (MPI_SUCCESS != rc) {
goto err;
Expand Down Expand Up @@ -634,16 +634,16 @@ int mca_spml_ucx_clear_put_op_mask(mca_spml_ucx_ctx_t *ctx)
int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs)
{
int rc = OSHMEM_ERROR;
int my_rank = oshmem_my_proc_id();
size_t ucp_workers = mca_spml_ucx.ucp_workers;
unsigned int *wk_roffs = NULL;
unsigned int *wk_rsizes = NULL;
char *wk_raddrs = NULL;
size_t i, w, n;
size_t i, j, w, n, temp;
ucs_status_t err;
ucp_address_t **wk_local_addr;
unsigned int *wk_addr_len;
ucp_ep_params_t ep_params;
int *indices;

wk_local_addr = calloc(mca_spml_ucx.ucp_workers, sizeof(ucp_address_t *));
wk_addr_len = calloc(mca_spml_ucx.ucp_workers, sizeof(size_t));
Expand Down Expand Up @@ -691,23 +691,40 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs)
}
}

indices = malloc(nprocs * sizeof(int));
if (!indices) {
goto error;
}

for (i = 0; i < nprocs; i++) {
indices[i] = i;
}

/*
 * Seed the PRNG differently on every PE. time(NULL) has one-second
 * resolution, so PEs launched together would otherwise share a seed, build
 * the same permutation and all create their first endpoint to the same peer.
 */
srand((unsigned int)time(NULL) ^ (unsigned int)oshmem_my_proc_id());

/*
 * Get the EP connection requests for all the processes from modex.
 *
 * Endpoints are created in shuffled order. n counts the endpoints handled so
 * far (and is what the error message reports); i walks indices[] from the
 * last slot down to 0. Both are size_t, so the loop is driven by n: a
 * "for (i = nprocs - 1; i >= 0; --i)" form never terminates on an unsigned
 * index and wraps past zero (also when nprocs == 0).
 */
for (n = 0; n < nprocs; ++n) {
    i = nprocs - 1 - n;

    /* Fisher-Yates shuffle step: choose the occupant of slot i from [0, i] */
    if (i > 0) {
        j = (size_t)rand() % (i + 1);
        temp = indices[i];
        indices[i] = indices[j];
        indices[j] = temp;
    }

    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    ep_params.address = (ucp_address_t *) mca_spml_ucx.remote_addrs_tbl[0][indices[i]];

    err = ucp_ep_create(mca_spml_ucx_ctx_default.ucp_worker[0], &ep_params,
                        &mca_spml_ucx_ctx_default.ucp_peers[indices[i]].ucp_conn);
    if (UCS_OK != err) {
        SPML_UCX_ERROR("ucp_ep_create(proc=%zu/%zu) failed: %s", n, nprocs,
                       ucs_status_string(err));
        goto error2;
    }

    /* Initialize mkeys as NULL for all processes */
    mca_spml_ucx_peer_mkey_cache_init(&mca_spml_ucx_ctx_default, indices[i]);
}

for (i = 0; i < mca_spml_ucx.ucp_workers; i++) {
Expand All @@ -719,6 +736,7 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs)
free(wk_roffs);
free(wk_addr_len);
free(wk_local_addr);
free(indices);

SPML_UCX_VERBOSE(50, "*** ADDED PROCS ***");

Expand Down Expand Up @@ -753,6 +771,7 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs)
free(wk_raddrs);
free(wk_rsizes);
free(wk_roffs);
free(indices);
error:
free(wk_addr_len);
free(wk_local_addr);
Expand Down Expand Up @@ -1025,7 +1044,7 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx

opal_atomic_wmb ();
}

static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx_ctx_p)
{
ucp_worker_params_t params;
Expand All @@ -1044,7 +1063,7 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
ucx_ctx->ucp_worker = calloc(1, sizeof(ucp_worker_h));
ucx_ctx->ucp_workers = 1;
ucx_ctx->synchronized_quiet = mca_spml_ucx_ctx_default.synchronized_quiet;
ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync;
ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync;

params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE ||
Expand Down

0 comments on commit b734683

Please sign in to comment.