Skip to content

Commit ba9b5e5

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla:gpu] Remove redundant Thunk::GetGpuCollectives
PiperOrigin-RevId: 837910560
1 parent ceca86d commit ba9b5e5

File tree

8 files changed

+37
-90
lines changed

8 files changed

+37
-90
lines changed

xla/backends/gpu/runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,7 @@ cc_library(
16681668
"//xla/core/collectives:rank_id",
16691669
"//xla/hlo/ir:collective_op_group_mode",
16701670
"//xla/hlo/ir:hlo",
1671+
"//xla/runtime:device_id",
16711672
"//xla/service:buffer_assignment",
16721673
"//xla/service:collective_ops_utils",
16731674
"//xla/service:computation_placer",

xla/backends/gpu/runtime/all_to_all_thunk.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,11 @@ absl::Status AllToAllStartThunk::Initialize(const InitializeParams& params) {
121121
<< "] Local device count : " << device_count_;
122122

123123
if (is_local() && p2p_memcpy_enabled_) {
124-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
125-
GetGpuCollectives(params));
126124
AsyncStreamKind stream_kind = GetAsyncStreamKind();
127125
TF_ASSIGN_OR_RETURN(
128126
CommunicatorHandle comm_handle,
129-
GetComm(collectives, *params.collective_params,
130-
*params.collective_cliques, config().replica_groups,
131-
config().group_mode, stream_kind));
127+
GetComm(*params.collective_params, *params.collective_cliques,
128+
config().replica_groups, config().group_mode, stream_kind));
132129
TF_ASSIGN_OR_RETURN(int32_t num_ranks, comm_handle.comm->NumRanks());
133130
se::StreamExecutor* executor = params.executor;
134131
{

xla/backends/gpu/runtime/collective_thunk.cc

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ limitations under the License.
4545
#include "xla/debug_options_flags.h"
4646
#include "xla/hlo/ir/collective_op_group_mode.h"
4747
#include "xla/primitive_util.h"
48+
#include "xla/runtime/device_id.h"
4849
#include "xla/service/collective_ops_utils.h"
4950
#include "xla/service/computation_placer.h"
5051
#include "xla/service/global_device_id.h"
@@ -197,10 +198,12 @@ CollectiveThunk::CollectiveThunk(Kind kind, ThunkInfo thunk_info, bool is_sync,
197198
async_events_(is_sync ? nullptr : std::make_shared<AsyncEvents>()) {}
198199

199200
absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey(
200-
GpuCollectives* collectives, const CollectiveParams& params,
201-
const std::vector<ReplicaGroup>& replica_groups,
201+
const CollectiveParams& params,
202+
absl::Span<const ReplicaGroup> replica_groups,
202203
CollectiveOpGroupMode group_mode, AsyncStreamKind stream_kind,
203204
bool use_nccl) {
205+
TF_RET_CHECK(params.collectives) << "Collectives API is not provided";
206+
204207
GlobalDeviceId global_device_id = params.global_device_id;
205208

206209
if (params.device_assn == nullptr) {
@@ -231,7 +234,7 @@ absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey(
231234
*params.device_assn, replica_groups, group_mode));
232235
}
233236

234-
if (collectives->IsGlobalConfig() &&
237+
if (params.collectives->IsGlobalConfig() &&
235238
(participants.size() != params.device_assn->replica_count())) {
236239
return InvalidArgument(
237240
"Partial replica groups are not allowed when using NCCL_COMM_ID "
@@ -272,22 +275,18 @@ absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey(
272275
absl::StatusOr<GpuCliqueKey> GetCollectiveGpuCliqueKey(
273276
const CollectiveParams& params, const CollectiveConfig& collective_config,
274277
bool use_nccl) {
275-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
276-
CollectiveThunk::GetGpuCollectives(params));
277-
return GetGpuCliqueKey(collectives, params, collective_config.replica_groups,
278-
collective_config.group_mode,
279-
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE,
280-
use_nccl);
278+
return GetGpuCliqueKey(
279+
params, collective_config.replica_groups, collective_config.group_mode,
280+
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE, use_nccl);
281281
}
282282

283283
absl::StatusOr<CommunicatorHandle> GetComm(
284-
GpuCollectives* collectives, const CollectiveParams& params,
285-
const CollectiveCliques& collective_cliques,
284+
const CollectiveParams& params, const CollectiveCliques& collective_cliques,
286285
const std::vector<ReplicaGroup>& replica_groups,
287286
CollectiveOpGroupMode group_mode, AsyncStreamKind stream_kind) {
288-
TF_ASSIGN_OR_RETURN(GpuCliqueKey clique_key,
289-
GetGpuCliqueKey(collectives, params, replica_groups,
290-
group_mode, stream_kind));
287+
TF_ASSIGN_OR_RETURN(
288+
GpuCliqueKey clique_key,
289+
GetGpuCliqueKey(params, replica_groups, group_mode, stream_kind));
291290

292291
std::optional<RankId> rank = clique_key.rank(params.global_device_id);
293292
TF_ASSIGN_OR_RETURN(Communicator * comm,
@@ -386,12 +385,10 @@ absl::StatusOr<se::Event*> CollectiveThunk::AsyncEvents::GetEvent(
386385

387386
absl::Status CollectiveThunk::Prepare(const PrepareParams& params) {
388387
TF_RET_CHECK(params.collective_params != nullptr);
389-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
390388
TF_ASSIGN_OR_RETURN(
391389
GpuCliqueKey clique_key,
392-
GetGpuCliqueKey(collectives, *params.collective_params,
393-
config().replica_groups, config().group_mode,
394-
GetAsyncStreamKind()));
390+
GetGpuCliqueKey(*params.collective_params, config().replica_groups,
391+
config().group_mode, GetAsyncStreamKind()));
395392
return params.clique_requests->RequestClique(clique_key);
396393
}
397394

@@ -407,12 +404,10 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
407404
"[%d] Starting %s %s.", params.stream->parent()->device_ordinal(),
408405
IsAsync() ? "async" : "sync", Thunk::KindToString(kind()));
409406
AsyncStreamKind stream_kind = GetAsyncStreamKind();
410-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
411407
TF_ASSIGN_OR_RETURN(
412408
CommunicatorHandle comm_handle,
413-
GetComm(collectives, *params.collective_params,
414-
*params.collective_cliques, config().replica_groups,
415-
config().group_mode, stream_kind));
409+
GetComm(*params.collective_params, *params.collective_cliques,
410+
config().replica_groups, config().group_mode, stream_kind));
416411
se::StreamExecutor* executor = params.stream->parent();
417412
int64_t async_stream_idx = static_cast<int64_t>(stream_kind);
418413

@@ -485,12 +480,10 @@ absl::Status CollectiveThunk::ExecuteOnStream(const ExecuteParams& params) {
485480
absl::StatusOr<std::vector<Communicator*>> CollectiveThunk::GetCommunicators(
486481
const ExecuteParams& params) const {
487482
AsyncStreamKind stream_kind = GetAsyncStreamKind();
488-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
489483
TF_ASSIGN_OR_RETURN(
490484
CommunicatorHandle comm_handle,
491-
GetComm(collectives, *params.collective_params,
492-
*params.collective_cliques, config().replica_groups,
493-
config().group_mode, stream_kind));
485+
GetComm(*params.collective_params, *params.collective_cliques,
486+
config().replica_groups, config().group_mode, stream_kind));
494487
return std::vector<Communicator*>{comm_handle.comm};
495488
}
496489

xla/backends/gpu/runtime/collective_thunk.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cstddef>
22

3+
#include "absl/types/span.h"
34
#include "xla/backends/gpu/runtime/collective_cliques.h"
45
#include "xla/backends/gpu/runtime/thunk.pb.h"
56
/* Copyright 2019 The OpenXLA Authors.
@@ -273,8 +274,8 @@ absl::Status AddOpDescription(absl::Status status, OpT op,
273274
//===----------------------------------------------------------------------===//
274275

275276
absl::StatusOr<GpuCliqueKey> GetGpuCliqueKey(
276-
GpuCollectives* collectives, const CollectiveParams& params,
277-
const std::vector<ReplicaGroup>& replica_groups,
277+
const CollectiveParams& params,
278+
absl::Span<const ReplicaGroup> replica_groups,
278279
CollectiveOpGroupMode group_mode, AsyncStreamKind stream_kind,
279280
bool use_nccl = true);
280281

@@ -285,8 +286,7 @@ absl::StatusOr<GpuCliqueKey> GetCollectiveGpuCliqueKey(
285286

286287
// Returns a communicator and additional information about the clique.
287288
absl::StatusOr<CommunicatorHandle> GetComm(
288-
GpuCollectives* collectives, const CollectiveParams& params,
289-
const CollectiveCliques& collective_cliques,
289+
const CollectiveParams& params, const CollectiveCliques& collective_cliques,
290290
const std::vector<ReplicaGroup>& replica_groups,
291291
CollectiveOpGroupMode group_mode, AsyncStreamKind stream_kind);
292292

xla/backends/gpu/runtime/command_buffer_cmd.cc

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2074,12 +2074,10 @@ CollectiveCmd::CollectiveCmd(
20742074

20752075
absl::Status CollectiveCmd::Prepare(const Thunk::PrepareParams& params) {
20762076
TF_RET_CHECK(params.collective_params != nullptr);
2077-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2078-
Thunk::GetGpuCollectives(params));
20792077
TF_ASSIGN_OR_RETURN(
20802078
GpuCliqueKey clique_key,
2081-
GetGpuCliqueKey(collectives, *params.collective_params,
2082-
config().replica_groups, config().group_mode,
2079+
GetGpuCliqueKey(*params.collective_params, config().replica_groups,
2080+
config().group_mode,
20832081
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE));
20842082
return params.clique_requests->RequestClique(clique_key);
20852083
}
@@ -2151,12 +2149,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllReduceCmd::Record(
21512149
"AllReduceCmd requires collective parameters and cliques");
21522150
}
21532151

2154-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2155-
Thunk::GetGpuCollectives(execute_params));
2156-
21572152
TF_ASSIGN_OR_RETURN(
21582153
CommunicatorHandle comm_handle,
2159-
GetComm(collectives, *execute_params.collective_params,
2154+
GetComm(*execute_params.collective_params,
21602155
*execute_params.collective_cliques, config().replica_groups,
21612156
config().group_mode,
21622157
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE)); // Use constant
@@ -2217,12 +2212,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> ReduceScatterCmd::Record(
22172212
"ReduceScatterCmd requires collective parameters and cliques");
22182213
}
22192214

2220-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2221-
Thunk::GetGpuCollectives(execute_params));
2222-
22232215
TF_ASSIGN_OR_RETURN(
22242216
CommunicatorHandle comm_handle,
2225-
GetComm(collectives, *execute_params.collective_params,
2217+
GetComm(*execute_params.collective_params,
22262218
*execute_params.collective_cliques, config().replica_groups,
22272219
config().group_mode,
22282220
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE)); // Use constant
@@ -2284,11 +2276,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllToAllCmd::Record(
22842276
"AllToAllCmd requires collective parameters and cliques");
22852277
}
22862278

2287-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2288-
Thunk::GetGpuCollectives(execute_params));
22892279
TF_ASSIGN_OR_RETURN(
22902280
CommunicatorHandle comm_handle,
2291-
GetComm(collectives, *execute_params.collective_params,
2281+
GetComm(*execute_params.collective_params,
22922282
*execute_params.collective_cliques, config().replica_groups,
22932283
config().group_mode,
22942284
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE)); // Use constant
@@ -2347,12 +2337,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> AllGatherCmd::Record(
23472337
"AllGatherCmd requires collective parameters and cliques");
23482338
}
23492339

2350-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2351-
Thunk::GetGpuCollectives(execute_params));
2352-
23532340
TF_ASSIGN_OR_RETURN(
23542341
CommunicatorHandle comm_handle,
2355-
GetComm(collectives, *execute_params.collective_params,
2342+
GetComm(*execute_params.collective_params,
23562343
*execute_params.collective_cliques, config().replica_groups,
23572344
config().group_mode,
23582345
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE)); // Use constant
@@ -2411,12 +2398,9 @@ CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params,
24112398
"CollectiveBroadcastCmd requires collective parameters and cliques");
24122399
}
24132400

2414-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2415-
Thunk::GetGpuCollectives(execute_params));
2416-
24172401
TF_ASSIGN_OR_RETURN(
24182402
CommunicatorHandle comm_handle,
2419-
GetComm(collectives, *execute_params.collective_params,
2403+
GetComm(*execute_params.collective_params,
24202404
*execute_params.collective_cliques, config().replica_groups,
24212405
config().group_mode,
24222406
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE)); // Use constant
@@ -2476,12 +2460,9 @@ absl::StatusOr<const se::CommandBuffer::Command*> CollectivePermuteCmd::Record(
24762460
"CollectivePermuteCmd requires collective parameters and cliques");
24772461
}
24782462

2479-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives,
2480-
Thunk::GetGpuCollectives(execute_params));
2481-
24822463
TF_ASSIGN_OR_RETURN(
24832464
CommunicatorHandle comm_handle,
2484-
GetComm(collectives, *execute_params.collective_params,
2465+
GetComm(*execute_params.collective_params,
24852466
*execute_params.collective_cliques, config().replica_groups,
24862467
config().group_mode,
24872468
AsyncStreamKind::ASYNC_STREAM_KIND_COLLECTIVE)); // Use constant

xla/backends/gpu/runtime/nvshmem_collective_thunk.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,11 @@ absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectivesFromRegistry() {
9191

9292
absl::Status NvshmemCollectiveThunk::Prepare(const PrepareParams& params) {
9393
TF_RET_CHECK(params.collective_params != nullptr);
94-
TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
9594
TF_ASSIGN_OR_RETURN(
9695
GpuCliqueKey clique_key,
97-
GetGpuCliqueKey(collectives, *params.collective_params,
98-
config().replica_groups, config().group_mode,
99-
GetAsyncStreamKind(), /*use_nccl= */ false));
96+
GetGpuCliqueKey(*params.collective_params, config().replica_groups,
97+
config().group_mode, GetAsyncStreamKind(),
98+
/*use_nccl= */ false));
10099
return params.clique_requests->RequestClique(clique_key);
101100
}
102101

xla/backends/gpu/runtime/thunk.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,14 +309,6 @@ ThunkMetadataListProto GetMetadataListProtoFromThunkGraph(
309309
return metadata_list_proto;
310310
}
311311

312-
absl::StatusOr<GpuCollectives* absl_nonnull> Thunk::GetGpuCollectives(
313-
const CollectiveParams& params) {
314-
if (params.collectives == nullptr) {
315-
return Internal("Collectives API is not provided");
316-
}
317-
return params.collectives;
318-
}
319-
320312
ThunkInfoProto Thunk::ThunkInfo::ToProto() const {
321313
ThunkInfoProto proto;
322314
proto.set_profile_annotation(profile_annotation);

xla/backends/gpu/runtime/thunk.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -447,22 +447,6 @@ class Thunk {
447447
return absl::OkStatus();
448448
}
449449

450-
// A helper function to get the `GpuCollectives*` pointer from the
451-
// CollectiveParams.
452-
static absl::StatusOr<GpuCollectives* absl_nonnull> GetGpuCollectives(
453-
CollectiveParams const& params);
454-
455-
// A helper function to get the `GpuCollectives*` pointer from the
456-
// thunk parameters. Returns an error if collectives API is not provided.
457-
template <typename Params>
458-
static absl::StatusOr<GpuCollectives* absl_nonnull> GetGpuCollectives(
459-
const Params& params) {
460-
if (params.collective_params == nullptr) {
461-
return Internal("Collective params are not provided");
462-
}
463-
return GetGpuCollectives(*params.collective_params);
464-
}
465-
466450
// Serializes the thunk into a `ThunkProto`.
467451
virtual absl::StatusOr<ThunkProto> ToProto() const;
468452

0 commit comments

Comments
 (0)