@@ -791,7 +791,7 @@ flagcxResult_t flagcxScatter(const void *sendbuff, void *recvbuff, size_t count,
791
791
return cclAdaptors[flagcxCCLAdaptorDevice]->scatter (
792
792
sendbuff, recvbuff, count, datatype, root, comm->homo_comm , stream);
793
793
} else {
794
- if (use_host_comm () || comm->has_single_rank_homo_comm ) {
794
+ if (use_host_comm () || comm->has_single_rank_homo_comm ) {
795
795
// c2c validation
796
796
if (comm->has_single_rank_homo_comm ) {
797
797
WARN (" Host comm is required to perform C2C scatter op when "
@@ -812,8 +812,9 @@ flagcxResult_t flagcxScatter(const void *sendbuff, void *recvbuff, size_t count,
812
812
813
813
// step 2: memcpy d2h
814
814
timers[TIMER_COLL_MEM_D2H] = clockNano ();
815
- deviceAdaptor->deviceMemcpy (buff_in, const_cast <void *>(sendbuff), totalSize,
816
- flagcxMemcpyDeviceToHost, NULL , NULL );
815
+ deviceAdaptor->deviceMemcpy (buff_in, const_cast <void *>(sendbuff),
816
+ totalSize, flagcxMemcpyDeviceToHost, NULL ,
817
+ NULL );
817
818
timers[TIMER_COLL_MEM_D2H] = clockNano () - timers[TIMER_COLL_MEM_D2H];
818
819
819
820
// step 3: scatter
@@ -844,7 +845,7 @@ flagcxResult_t flagcxScatter(const void *sendbuff, void *recvbuff, size_t count,
844
845
timers[TIMER_COLL_TOTAL] / 1e6 , timers[TIMER_COLL_ALLOC] / 1e6 ,
845
846
timers[TIMER_COLL_FREE] / 1e6 , timers[TIMER_COLL_MEM_D2H] / 1e6 ,
846
847
timers[TIMER_COLL_MEM_H2D] / 1e6 , timers[TIMER_COLL_COMM] / 1e6 );
847
- }else {
848
+ } else {
848
849
// Experimental for multi-nic support
849
850
// Construct flagcxC2cPlanner and find corresponding strategy
850
851
flagcxC2cPlanner planner;
@@ -1369,21 +1370,27 @@ flagcxResult_t flagcxAlltoAllv(const void *sendbuff, size_t *sendcounts,
1369
1370
void *buff_out;
1370
1371
1371
1372
// Calculate max possible size needed for send and receive buffers
1372
- size_t max_send_size = 0 , max_recv_size = 0 , send_size = 0 , recv_size = 0 ;
1373
+ size_t max_send_size = 0 , max_recv_size = 0 , send_size = 0 , recv_size = 0 ;
1373
1374
for (int i = 0 ; i < comm->nranks ; i++) {
1374
- send_size = (sendcounts[i] + sdispls[i]) * getFlagcxDataTypeSize (datatype);
1375
- recv_size = (recvcounts[i] + rdispls[i]) * getFlagcxDataTypeSize (datatype);
1376
- if (send_size > max_send_size) max_send_size = send_size;
1377
- if (recv_size > max_recv_size) max_recv_size = recv_size;
1375
+ send_size =
1376
+ (sendcounts[i] + sdispls[i]) * getFlagcxDataTypeSize (datatype);
1377
+ recv_size =
1378
+ (recvcounts[i] + rdispls[i]) * getFlagcxDataTypeSize (datatype);
1379
+ if (send_size > max_send_size)
1380
+ max_send_size = send_size;
1381
+ if (recv_size > max_recv_size)
1382
+ max_recv_size = recv_size;
1378
1383
}
1379
1384
timers[TIMER_COLL_ALLOC] = clockNano ();
1380
1385
deviceAdaptor->deviceMalloc (&buff_in, max_send_size, flagcxMemHost, NULL );
1381
- deviceAdaptor->deviceMalloc (&buff_out, max_recv_size, flagcxMemHost, NULL );
1386
+ deviceAdaptor->deviceMalloc (&buff_out, max_recv_size, flagcxMemHost,
1387
+ NULL );
1382
1388
timers[TIMER_COLL_ALLOC] = clockNano () - timers[TIMER_COLL_ALLOC];
1383
1389
1384
1390
timers[TIMER_COLL_MEM_D2H] = clockNano ();
1385
- deviceAdaptor->deviceMemcpy (buff_in, const_cast <void *>(sendbuff), max_send_size,
1386
- flagcxMemcpyDeviceToHost, NULL , NULL );
1391
+ deviceAdaptor->deviceMemcpy (buff_in, const_cast <void *>(sendbuff),
1392
+ max_send_size, flagcxMemcpyDeviceToHost, NULL ,
1393
+ NULL );
1387
1394
timers[TIMER_COLL_MEM_D2H] = clockNano () - timers[TIMER_COLL_MEM_D2H];
1388
1395
1389
1396
timers[TIMER_COLL_COMM] = clockNano ();
@@ -1405,7 +1412,8 @@ flagcxResult_t flagcxAlltoAllv(const void *sendbuff, size_t *sendcounts,
1405
1412
timers[TIMER_COLL_TOTAL] = clockNano () - timers[TIMER_COLL_TOTAL];
1406
1413
INFO (FLAGCX_COLL,
1407
1414
" Flagcx timings - %s AlltoAllv: rank %d nranks %d total %.2fms "
1408
- " (memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, memory h2d %.2fms, comm %.2fms)" ,
1415
+ " (memory alloc %.2fms, memory free %.2fms, memory d2h %.2fms, "
1416
+ " memory h2d %.2fms, comm %.2fms)" ,
1409
1417
cclAdaptors[flagcxCCLAdaptorHost]->name , comm->rank , comm->nranks ,
1410
1418
timers[TIMER_COLL_TOTAL] / 1e6 , timers[TIMER_COLL_ALLOC] / 1e6 ,
1411
1419
timers[TIMER_COLL_FREE] / 1e6 , timers[TIMER_COLL_MEM_D2H] / 1e6 ,
@@ -1488,8 +1496,14 @@ flagcxResult_t flagcxSend(const void *sendbuff, size_t count,
1488
1496
timers[TIMER_COLL_TOTAL] / 1e6 , timers[TIMER_COLL_ALLOC] / 1e6 ,
1489
1497
timers[TIMER_COLL_MEM_D2H] / 1e6 , timers[TIMER_COLL_COMM] / 1e6 );
1490
1498
} else {
1491
- FLAGCXCHECK (flagcxHeteroSend (sendbuff, count, datatype, peer,
1492
- comm->hetero_comm , stream));
1499
+ if (comm->cluster_ids [comm->rank ] == comm->cluster_ids [peer]) {
1500
+ FLAGCXCHECK (cclAdaptors[flagcxCCLAdaptorDevice]->send (
1501
+ sendbuff, count, datatype, comm->globalrank2homorank [peer],
1502
+ comm->homo_comm , stream));
1503
+ } else {
1504
+ FLAGCXCHECK (flagcxHeteroSend (sendbuff, count, datatype, peer,
1505
+ comm->hetero_comm , stream));
1506
+ }
1493
1507
}
1494
1508
}
1495
1509
return flagcxSuccess;
@@ -1541,37 +1555,41 @@ flagcxResult_t flagcxRecv(void *recvbuff, size_t count,
1541
1555
timers[TIMER_COLL_FREE] / 1e6 , timers[TIMER_COLL_MEM_H2D] / 1e6 ,
1542
1556
timers[TIMER_COLL_COMM] / 1e6 );
1543
1557
} else {
1544
- FLAGCXCHECK (flagcxHeteroRecv (recvbuff, count, datatype, peer,
1545
- comm->hetero_comm , stream));
1558
+ if (comm->cluster_ids [comm->rank ] == comm->cluster_ids [peer]) {
1559
+ FLAGCXCHECK (cclAdaptors[flagcxCCLAdaptorDevice]->recv (
1560
+ recvbuff, count, datatype, comm->globalrank2homorank [peer],
1561
+ comm->homo_comm , stream));
1562
+ } else {
1563
+ FLAGCXCHECK (flagcxHeteroRecv (recvbuff, count, datatype, peer,
1564
+ comm->hetero_comm , stream));
1565
+ }
1546
1566
}
1547
1567
}
1548
1568
return flagcxSuccess;
1549
1569
}
1550
1570
1551
1571
flagcxResult_t flagcxGroupStart (flagcxComm_t comm) {
1552
1572
FLAGCXCHECK (flagcxEnsureCommReady (comm));
1553
- if (is_homo_comm (comm)) {
1554
- return cclAdaptors[flagcxCCLAdaptorDevice]->groupStart ();
1573
+ if (!is_homo_comm (comm)) {
1574
+ FLAGCXCHECK (flagcxHeteroGroupStart ());
1575
+ }
1576
+ if (use_host_comm ()) {
1577
+ FLAGCXCHECK (cclAdaptors[flagcxCCLAdaptorHost]->groupStart ());
1555
1578
} else {
1556
- if (use_host_comm ()) {
1557
- cclAdaptors[flagcxCCLAdaptorHost]->groupStart ();
1558
- } else {
1559
- FLAGCXCHECK (flagcxHeteroGroupStart ());
1560
- }
1579
+ FLAGCXCHECK (cclAdaptors[flagcxCCLAdaptorDevice]->groupStart ());
1561
1580
}
1562
1581
return flagcxSuccess;
1563
1582
}
1564
1583
1565
1584
flagcxResult_t flagcxGroupEnd (flagcxComm_t comm) {
1566
1585
FLAGCXCHECK (flagcxEnsureCommReady (comm));
1567
- if (is_homo_comm (comm )) {
1568
- return cclAdaptors[flagcxCCLAdaptorDevice ]->groupEnd ();
1586
+ if (use_host_comm ( )) {
1587
+ FLAGCXCHECK ( cclAdaptors[flagcxCCLAdaptorHost ]->groupEnd () );
1569
1588
} else {
1570
- if (use_host_comm ()) {
1571
- cclAdaptors[flagcxCCLAdaptorHost]->groupEnd ();
1572
- } else {
1573
- FLAGCXCHECK (flagcxHeteroGroupEnd ());
1574
- }
1589
+ FLAGCXCHECK (cclAdaptors[flagcxCCLAdaptorDevice]->groupEnd ());
1590
+ }
1591
+ if (!is_homo_comm (comm)) {
1592
+ FLAGCXCHECK (flagcxHeteroGroupEnd ());
1575
1593
}
1576
1594
return flagcxSuccess;
1577
- }
1595
+ }
0 commit comments