@@ -45,6 +45,7 @@ limitations under the License.
4545#include " xla/debug_options_flags.h"
4646#include " xla/hlo/ir/collective_op_group_mode.h"
4747#include " xla/primitive_util.h"
48+ #include " xla/runtime/device_id.h"
4849#include " xla/service/collective_ops_utils.h"
4950#include " xla/service/computation_placer.h"
5051#include " xla/service/global_device_id.h"
@@ -197,10 +198,12 @@ CollectiveThunk::CollectiveThunk(Kind kind, ThunkInfo thunk_info, bool is_sync,
197198 async_events_(is_sync ? nullptr : std::make_shared<AsyncEvents>()) {}
198199
199200absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey (
200- GpuCollectives* collectives, const CollectiveParams& params,
201- const std::vector< ReplicaGroup>& replica_groups,
201+ const CollectiveParams& params,
202+ absl::Span< const ReplicaGroup> replica_groups,
202203 CollectiveOpGroupMode group_mode, AsyncStreamKind stream_kind,
203204 bool use_nccl) {
205+ TF_RET_CHECK (params.collectives ) << " Collectives API is not provided" ;
206+
204207 GlobalDeviceId global_device_id = params.global_device_id ;
205208
206209 if (params.device_assn == nullptr ) {
@@ -231,7 +234,7 @@ absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey(
231234 *params.device_assn , replica_groups, group_mode));
232235 }
233236
234- if (collectives->IsGlobalConfig () &&
237+ if (params. collectives ->IsGlobalConfig () &&
235238 (participants.size () != params.device_assn ->replica_count ())) {
236239 return InvalidArgument (
237240 " Partial replica groups are not allowed when using NCCL_COMM_ID "
@@ -272,22 +275,18 @@ absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey(
272275absl::StatusOr<GpuCliqueKey> GetCollectiveGpuCliqueKey (
273276 const CollectiveParams& params, const CollectiveConfig& collective_config,
274277 bool use_nccl) {
275- TF_ASSIGN_OR_RETURN (GpuCollectives * collectives,
276- CollectiveThunk::GetGpuCollectives (params));
277- return GetGpuCliqueKey (collectives, params, collective_config.replica_groups ,
278- collective_config.group_mode ,
279- AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE,
280- use_nccl);
278+ return GetGpuCliqueKey (
279+ params, collective_config.replica_groups , collective_config.group_mode ,
280+ AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE, use_nccl);
281281}
282282
283283absl::StatusOr<CommunicatorHandle> GetComm (
284- GpuCollectives* collectives, const CollectiveParams& params,
285- const CollectiveCliques& collective_cliques,
284+ const CollectiveParams& params, const CollectiveCliques& collective_cliques,
286285 const std::vector<ReplicaGroup>& replica_groups,
287286 CollectiveOpGroupMode group_mode, AsyncStreamKind stream_kind) {
288- TF_ASSIGN_OR_RETURN (GpuCliqueKey clique_key,
289- GetGpuCliqueKey (collectives, params, replica_groups ,
290- group_mode, stream_kind));
287+ TF_ASSIGN_OR_RETURN (
288+ GpuCliqueKey clique_key ,
289+ GetGpuCliqueKey (params, replica_groups, group_mode, stream_kind));
291290
292291 std::optional<RankId> rank = clique_key.rank (params.global_device_id );
293292 TF_ASSIGN_OR_RETURN (Communicator * comm,
@@ -386,12 +385,10 @@ absl::StatusOr<se::Event*> CollectiveThunk::AsyncEvents::GetEvent(
386385
387386absl::Status CollectiveThunk::Prepare (const PrepareParams& params) {
388387 TF_RET_CHECK (params.collective_params != nullptr );
389- TF_ASSIGN_OR_RETURN (GpuCollectives * collectives, GetGpuCollectives (params));
390388 TF_ASSIGN_OR_RETURN (
391389 GpuCliqueKey clique_key,
392- GetGpuCliqueKey (collectives, *params.collective_params ,
393- config ().replica_groups , config ().group_mode ,
394- GetAsyncStreamKind ()));
390+ GetGpuCliqueKey (*params.collective_params , config ().replica_groups ,
391+ config ().group_mode , GetAsyncStreamKind ()));
395392 return params.clique_requests ->RequestClique (clique_key);
396393}
397394
@@ -407,12 +404,10 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
407404 " [%d] Starting %s %s." , params.stream ->parent ()->device_ordinal (),
408405 IsAsync () ? " async" : " sync" , Thunk::KindToString (kind ()));
409406 AsyncStreamKind stream_kind = GetAsyncStreamKind ();
410- TF_ASSIGN_OR_RETURN (GpuCollectives * collectives, GetGpuCollectives (params));
411407 TF_ASSIGN_OR_RETURN (
412408 CommunicatorHandle comm_handle,
413- GetComm (collectives, *params.collective_params ,
414- *params.collective_cliques , config ().replica_groups ,
415- config ().group_mode , stream_kind));
409+ GetComm (*params.collective_params , *params.collective_cliques ,
410+ config ().replica_groups , config ().group_mode , stream_kind));
416411 se::StreamExecutor* executor = params.stream ->parent ();
417412 int64_t async_stream_idx = static_cast <int64_t >(stream_kind);
418413
@@ -485,12 +480,10 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
485480absl::StatusOr<std::vector<Communicator*>> CollectiveThunk::GetCommunicators (
486481 const ExecuteParams& params) const {
487482 AsyncStreamKind stream_kind = GetAsyncStreamKind ();
488- TF_ASSIGN_OR_RETURN (GpuCollectives * collectives, GetGpuCollectives (params));
489483 TF_ASSIGN_OR_RETURN (
490484 CommunicatorHandle comm_handle,
491- GetComm (collectives, *params.collective_params ,
492- *params.collective_cliques , config ().replica_groups ,
493- config ().group_mode , stream_kind));
485+ GetComm (*params.collective_params , *params.collective_cliques ,
486+ config ().replica_groups , config ().group_mode , stream_kind));
494487 return std::vector<Communicator*>{comm_handle.comm };
495488}
496489
0 commit comments