diff --git a/ompi/mca/coll/accelerator/coll_accelerator.h b/ompi/mca/coll/accelerator/coll_accelerator.h index e707d7ec7f2..103710b9da8 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator.h +++ b/ompi/mca/coll/accelerator/coll_accelerator.h @@ -31,6 +31,36 @@ BEGIN_C_DECLS +#define COLL_ACC_ALLGATHER 0x00000001 +#define COLL_ACC_ALLGATHERV 0x00000002 +#define COLL_ACC_ALLREDUCE 0x00000004 +#define COLL_ACC_ALLTOALL 0x00000008 +#define COLL_ACC_ALLTOALLV 0x00000010 +#define COLL_ACC_ALLTOALLW 0x00000020 +#define COLL_ACC_BARRIER 0x00000040 +#define COLL_ACC_BCAST 0x00000080 +#define COLL_ACC_EXSCAN 0x00000100 +#define COLL_ACC_GATHER 0x00000200 +#define COLL_ACC_GATHERV 0x00000400 +#define COLL_ACC_REDUCE 0x00000800 +#define COLL_ACC_REDUCE_SCATTER 0x00001000 +#define COLL_ACC_REDUCE_SCATTER_BLOCK 0x00002000 +#define COLL_ACC_REDUCE_LOCAL 0x00004000 +#define COLL_ACC_SCAN 0x00008000 +#define COLL_ACC_SCATTER 0x00010000 +#define COLL_ACC_SCATTERV 0x00020000 +#define COLL_ACC_NEIGHBOR_ALLGATHER 0x00040000 +#define COLL_ACC_NEIGHBOR_ALLGATHERV 0x00080000 +#define COLL_ACC_NEIGHBOR_ALLTOALL 0x00100000 +#define COLL_ACC_NEIGHBOR_ALLTTOALLV 0x00200000 +#define COLL_ACC_NEIGHBOR_ALLTTOALLW 0x00400000 +#define COLL_ACC_LASTCOLL 0x00800000 + +#define COLL_ACCELERATOR_CTS_STR "allreduce,reduce_scatter_block,reduce_local,reduce,scan,exscan" +#define COLL_ACCELERATOR_CTS COLL_ACC_ALLREDUCE | COLL_ACC_REDUCE | \ + COLL_ACC_REDUCE_SCATTER_BLOCK | COLL_ACC_REDUCE_LOCAL | \ + COLL_ACC_EXSCAN | COLL_ACC_SCAN + /* API functions */ int mca_coll_accelerator_init_query(bool enable_progress_threads, @@ -131,6 +161,8 @@ typedef struct mca_coll_accelerator_component_t { int priority; /* Priority of this component */ int disable_accelerator_coll; /* Force disable of the accelerator collective component */ + char *cts; /* String of collective operations which the component shall register itself */ + uint64_t cts_requested; } mca_coll_accelerator_component_t; /* Globally exported variables */ diff --git a/ompi/mca/coll/accelerator/coll_accelerator_component.c b/ompi/mca/coll/accelerator/coll_accelerator_component.c index c4ba508026b..c6e43cb3468 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_component.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_component.c @@ -7,6 +7,7 @@ * Copyright (c) 2015 Los Alamos National Security, LLC. All rights * reserved. * Copyright (c) 2024 Triad National Security, LLC. All rights reserved. + * Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -21,6 +22,7 @@ #include "mpi.h" #include "ompi/constants.h" #include "coll_accelerator.h" +#include "opal/util/argv.h" /* * Public string showing the coll ompi_accelerator component version number @@ -31,6 +33,7 @@ const char *mca_coll_accelerator_component_version_string = /* * Local function */ +static int accelerator_open(void); static int accelerator_register(void); /* @@ -52,6 +55,7 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = { OMPI_RELEASE_VERSION), /* Component open and close functions */ + .mca_open_component = accelerator_open, .mca_register_component_params = accelerator_register, }, .collm_data = { @@ -75,7 +79,8 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = { static int accelerator_register(void) { (void) mca_base_component_var_register(&mca_coll_accelerator_component.super.collm_version, - "priority", "Priority of the accelerator coll component; only relevant if barrier_before or barrier_after is > 0", + "priority", "Priority of the accelerator coll component; only relevant if barrier_before " + "or barrier_after is > 0", MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_6, MCA_BASE_VAR_SCOPE_READONLY, @@ -88,5 +93,76 @@ static int accelerator_register(void) MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_accelerator_component.disable_accelerator_coll); + mca_coll_accelerator_component.cts = COLL_ACCELERATOR_CTS_STR; + (void)mca_base_component_var_register(&mca_coll_accelerator_component.super.collm_version, + "cts", "Comma separated list of collectives to be enabled", + MCA_BASE_VAR_TYPE_STRING, NULL, 0, MCA_BASE_VAR_FLAG_SETTABLE, + OPAL_INFO_LVL_6, MCA_BASE_VAR_SCOPE_ALL, &mca_coll_accelerator_component.cts); + + return OMPI_SUCCESS; +} + + +/* The string parsing is based on the code available in the coll/ucc component */ +static uint64_t mca_coll_accelerator_str_to_type(const char *str) +{ + if (0 == strcasecmp(str, "allreduce")) { + return COLL_ACC_ALLREDUCE; + } else if (0 == strcasecmp(str, "reduce_scatter_block")) { + return COLL_ACC_REDUCE_SCATTER_BLOCK; + } else if (0 == strcasecmp(str, "reduce_local")) { + return COLL_ACC_REDUCE_LOCAL; + } else if (0 == strcasecmp(str, "reduce")) { + return COLL_ACC_REDUCE; + } else if (0 == strcasecmp(str, "exscan")) { + return COLL_ACC_EXSCAN; + } else if (0 == strcasecmp(str, "scan")) { + return COLL_ACC_SCAN; + } + opal_output(0, "incorrect value for cts: %s, allowed: %s", + str, COLL_ACCELERATOR_CTS_STR); + return COLL_ACC_LASTCOLL; +} + +static void accelerator_init_default_cts(void) +{ + mca_coll_accelerator_component_t *cm = &mca_coll_accelerator_component; + bool disable; + char** cts; + int n_cts, i; + char* str; + uint64_t *ct, c; + + disable = (cm->cts[0] == '^') ? true : false; + cts = opal_argv_split(disable ? (cm->cts + 1) : cm->cts, ','); + n_cts = opal_argv_count(cts); + cm->cts_requested = disable ? COLL_ACCELERATOR_CTS : 0; + for (i = 0; i < n_cts; i++) { + if (('i' == cts[i][0]) || ('I' == cts[i][0])) { + /* non blocking collective setting */ + opal_output(0, "coll/accelerator component does not support non-blocking collectives at this time." + " Ignoring collective: %s\n", cts[i]); + continue; + } else { + str = cts[i]; + ct = &cm->cts_requested; + } + c = mca_coll_accelerator_str_to_type(str); + if (COLL_ACC_LASTCOLL == c) { + *ct = COLL_ACCELERATOR_CTS; + break; + } + if (disable) { + (*ct) &= ~c; + } else { + (*ct) |= c; + } + } + opal_argv_free(cts); +} + +static int accelerator_open(void) +{ + accelerator_init_default_cts(); return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/accelerator/coll_accelerator_module.c b/ompi/mca/coll/accelerator/coll_accelerator_module.c index 4005f6cdec9..b0b318726b0 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_module.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_module.c @@ -6,7 +6,7 @@ * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2019 Research Organization for Information Science * and Technology (RIST). All rights reserved. - * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2024 Triad National Security, LLC. All rights reserved. * $COPYRIGHT$ * @@ -106,18 +106,21 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, } -#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \ +#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api, __API) \ do \ { \ if ((__comm)->c_coll->coll_##__api) \ { \ - MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ - MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \ + if (mca_coll_accelerator_component.cts_requested & COLL_ACC_##__API) \ + { \ + MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \ + } \ } \ else \ { \ opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \ - "cuda", #__api, ompi_process_info.nodename, \ + "accelerator", #__api, ompi_process_info.nodename, \ mca_coll_accelerator_component.priority); \ } \ } while (0) @@ -141,14 +144,14 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, { mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; - ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); - ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); - ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local); - ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); + ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce, ALLREDUCE); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce, REDUCE); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local, REDUCE_LOCAL); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block, REDUCE_SCATTER_BLOCK); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ - ACCELERATOR_INSTALL_COLL_API(comm, s, exscan); - ACCELERATOR_INSTALL_COLL_API(comm, s, scan); + ACCELERATOR_INSTALL_COLL_API(comm, s, exscan, EXSCAN); + ACCELERATOR_INSTALL_COLL_API(comm, s, scan, SCAN); } return OMPI_SUCCESS;