Skip to content

Commit bf609db

Browse files
authored
Use homo sendrecv in flagcxSend/Recv for intra-cluster communication (#182)
1 parent 21288ce commit bf609db

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ FlagCX leverages native collective communications libraries to provide the full
2323
| Mode | Homo | Homo | Homo | Homo | Homo | Homo | Hetero | Hetero | Hetero |
2424
| send ||||||||||
2525
| recv ||||||||||
26-
| broadcast |||||||| ||
27-
| gather |||||||| ||
28-
| scatter |||||||| ||
29-
| reduce |||||||| ||
26+
| broadcast |||||||| ||
27+
| gather |||||||| ||
28+
| scatter |||||||| ||
29+
| reduce |||||||| ||
3030
| allreduce ||||||||||
3131
| allgather ||||||||||
3232
| reducescatter ||||||||||
3333
| alltoall ||||||||||
34-
| alltoallv ||||||| |||
34+
| alltoallv ||||||| |||
3535
| group ops ||||||||||
3636

3737
Note that the `Homo` and `Hetero` modes refer to communication among homogeneous clusters and among heterogeneous clusters, respectively. Except for `BOOTSTRAP` (which is constructed by the FlagCX `bootstrap` component), all the other native collective communication libraries can be referenced through the links below:

flagcx/flagcx.cc

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ flagcxResult_t flagcxScatter(const void *sendbuff, void *recvbuff, size_t count,
791791
return cclAdaptors[flagcxCCLAdaptorDevice]->scatter(
792792
sendbuff, recvbuff, count, datatype, root, comm->homo_comm, stream);
793793
} else {
794-
if (use_host_comm() || comm->has_single_rank_homo_comm) {
794+
if (use_host_comm() || comm->has_single_rank_homo_comm) {
795795
// c2c validation
796796
if (comm->has_single_rank_homo_comm) {
797797
WARN("Host comm is required to perform C2C scatter op when "
@@ -812,8 +812,9 @@ flagcxResult_t flagcxScatter(const void *sendbuff, void *recvbuff, size_t count,
812812

813813
// step 2: memcpy d2h
814814
timers[TIMER_COLL_MEM_D2H] = clockNano();
815-
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff), totalSize,
816-
flagcxMemcpyDeviceToHost, NULL, NULL);
815+
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff),
816+
totalSize, flagcxMemcpyDeviceToHost, NULL,
817+
NULL);
817818
timers[TIMER_COLL_MEM_D2H] = clockNano() - timers[TIMER_COLL_MEM_D2H];
818819

819820
// step 3: scatter
@@ -844,7 +845,7 @@ flagcxResult_t flagcxScatter(const void *sendbuff, void *recvbuff, size_t count,
844845
timers[TIMER_COLL_TOTAL] / 1e6, timers[TIMER_COLL_ALLOC] / 1e6,
845846
timers[TIMER_COLL_FREE] / 1e6, timers[TIMER_COLL_MEM_D2H] / 1e6,
846847
timers[TIMER_COLL_MEM_H2D] / 1e6, timers[TIMER_COLL_COMM] / 1e6);
847-
}else {
848+
} else {
848849
// Experimental for multi-nic support
849850
// Construct flagcxC2cPlanner and find corresponding strategy
850851
flagcxC2cPlanner planner;
@@ -1369,21 +1370,27 @@ flagcxResult_t flagcxAlltoAllv(const void *sendbuff, size_t *sendcounts,
13691370
void *buff_out;
13701371

13711372
// Calculate max possible size needed for send and receive buffers
1372-
size_t max_send_size = 0, max_recv_size = 0 , send_size = 0 , recv_size = 0;
1373+
size_t max_send_size = 0, max_recv_size = 0, send_size = 0, recv_size = 0;
13731374
for (int i = 0; i < comm->nranks; i++) {
1374-
send_size = (sendcounts[i] + sdispls[i]) * getFlagcxDataTypeSize(datatype);
1375-
recv_size = (recvcounts[i] + rdispls[i]) * getFlagcxDataTypeSize(datatype);
1376-
if (send_size > max_send_size) max_send_size = send_size;
1377-
if (recv_size > max_recv_size) max_recv_size = recv_size;
1375+
send_size =
1376+
(sendcounts[i] + sdispls[i]) * getFlagcxDataTypeSize(datatype);
1377+
recv_size =
1378+
(recvcounts[i] + rdispls[i]) * getFlagcxDataTypeSize(datatype);
1379+
if (send_size > max_send_size)
1380+
max_send_size = send_size;
1381+
if (recv_size > max_recv_size)
1382+
max_recv_size = recv_size;
13781383
}
13791384
timers[TIMER_COLL_ALLOC] = clockNano();
13801385
deviceAdaptor->deviceMalloc(&buff_in, max_send_size, flagcxMemHost, NULL);
1381-
deviceAdaptor->deviceMalloc(&buff_out, max_recv_size, flagcxMemHost, NULL);
1386+
deviceAdaptor->deviceMalloc(&buff_out, max_recv_size, flagcxMemHost,
1387+
NULL);
13821388
timers[TIMER_COLL_ALLOC] = clockNano() - timers[TIMER_COLL_ALLOC];
13831389

13841390
timers[TIMER_COLL_MEM_D2H] = clockNano();
1385-
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff), max_send_size,
1386-
flagcxMemcpyDeviceToHost, NULL, NULL);
1391+
deviceAdaptor->deviceMemcpy(buff_in, const_cast<void *>(sendbuff),
1392+
max_send_size, flagcxMemcpyDeviceToHost, NULL,
1393+
NULL);
13871394
timers[TIMER_COLL_MEM_D2H] = clockNano() - timers[TIMER_COLL_MEM_D2H];
13881395

13891396
timers[TIMER_COLL_COMM] = clockNano();
@@ -1405,7 +1412,8 @@ flagcxResult_t flagcxAlltoAllv(const void *sendbuff, size_t *sendcounts,
14051412
timers[TIMER_COLL_TOTAL] = clockNano() - timers[TIMER_COLL_TOTAL];
14061413
INFO(FLAGCX_COLL,
14071414
"Flagcx timings - %s AlltoAllv: rank %d nranks %d total %.2fms "
1408-
"(memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, memory h2d %.2fms, comm %.2fms)",
1415+
"(memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, "
1416+
"memory h2d %.2fms, comm %.2fms)",
14091417
cclAdaptors[flagcxCCLAdaptorHost]->name, comm->rank, comm->nranks,
14101418
timers[TIMER_COLL_TOTAL] / 1e6, timers[TIMER_COLL_ALLOC] / 1e6,
14111419
timers[TIMER_COLL_FREE] / 1e6, timers[TIMER_COLL_MEM_D2H] / 1e6,
@@ -1488,8 +1496,14 @@ flagcxResult_t flagcxSend(const void *sendbuff, size_t count,
14881496
timers[TIMER_COLL_TOTAL] / 1e6, timers[TIMER_COLL_ALLOC] / 1e6,
14891497
timers[TIMER_COLL_MEM_D2H] / 1e6, timers[TIMER_COLL_COMM] / 1e6);
14901498
} else {
1491-
FLAGCXCHECK(flagcxHeteroSend(sendbuff, count, datatype, peer,
1492-
comm->hetero_comm, stream));
1499+
if (comm->cluster_ids[comm->rank] == comm->cluster_ids[peer]) {
1500+
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->send(
1501+
sendbuff, count, datatype, comm->globalrank2homorank[peer],
1502+
comm->homo_comm, stream));
1503+
} else {
1504+
FLAGCXCHECK(flagcxHeteroSend(sendbuff, count, datatype, peer,
1505+
comm->hetero_comm, stream));
1506+
}
14931507
}
14941508
}
14951509
return flagcxSuccess;
@@ -1541,37 +1555,41 @@ flagcxResult_t flagcxRecv(void *recvbuff, size_t count,
15411555
timers[TIMER_COLL_FREE] / 1e6, timers[TIMER_COLL_MEM_H2D] / 1e6,
15421556
timers[TIMER_COLL_COMM] / 1e6);
15431557
} else {
1544-
FLAGCXCHECK(flagcxHeteroRecv(recvbuff, count, datatype, peer,
1545-
comm->hetero_comm, stream));
1558+
if (comm->cluster_ids[comm->rank] == comm->cluster_ids[peer]) {
1559+
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->recv(
1560+
recvbuff, count, datatype, comm->globalrank2homorank[peer],
1561+
comm->homo_comm, stream));
1562+
} else {
1563+
FLAGCXCHECK(flagcxHeteroRecv(recvbuff, count, datatype, peer,
1564+
comm->hetero_comm, stream));
1565+
}
15461566
}
15471567
}
15481568
return flagcxSuccess;
15491569
}
15501570

15511571
flagcxResult_t flagcxGroupStart(flagcxComm_t comm) {
15521572
FLAGCXCHECK(flagcxEnsureCommReady(comm));
1553-
if (is_homo_comm(comm)) {
1554-
return cclAdaptors[flagcxCCLAdaptorDevice]->groupStart();
1573+
if (!is_homo_comm(comm)) {
1574+
FLAGCXCHECK(flagcxHeteroGroupStart());
1575+
}
1576+
if (use_host_comm()) {
1577+
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorHost]->groupStart());
15551578
} else {
1556-
if (use_host_comm()) {
1557-
cclAdaptors[flagcxCCLAdaptorHost]->groupStart();
1558-
} else {
1559-
FLAGCXCHECK(flagcxHeteroGroupStart());
1560-
}
1579+
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->groupStart());
15611580
}
15621581
return flagcxSuccess;
15631582
}
15641583

15651584
flagcxResult_t flagcxGroupEnd(flagcxComm_t comm) {
15661585
FLAGCXCHECK(flagcxEnsureCommReady(comm));
1567-
if (is_homo_comm(comm)) {
1568-
return cclAdaptors[flagcxCCLAdaptorDevice]->groupEnd();
1586+
if (use_host_comm()) {
1587+
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorHost]->groupEnd());
15691588
} else {
1570-
if (use_host_comm()) {
1571-
cclAdaptors[flagcxCCLAdaptorHost]->groupEnd();
1572-
} else {
1573-
FLAGCXCHECK(flagcxHeteroGroupEnd());
1574-
}
1589+
FLAGCXCHECK(cclAdaptors[flagcxCCLAdaptorDevice]->groupEnd());
1590+
}
1591+
if (!is_homo_comm(comm)) {
1592+
FLAGCXCHECK(flagcxHeteroGroupEnd());
15751593
}
15761594
return flagcxSuccess;
1577-
}
1595+
}

0 commit comments

Comments
 (0)