Skip to content

Commit 231badc

Browse files
olegshyshkovGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Derive replica and partition counts from HloModule in ExecuteReplicated.
PiperOrigin-RevId: 837843506
1 parent 2039fa8 commit 231badc

File tree

3 files changed

+90
-171
lines changed

3 files changed

+90
-171
lines changed

xla/tests/collective_ops_e2e_test_base.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,10 @@ CollectiveOpsE2ETestBase::ExecuteReplicated(
180180
absl::StatusOr<CollectiveOpsE2ETestBase::ExecutionResult>
181181
CollectiveOpsE2ETestBase::ExecuteReplicated(
182182
std::unique_ptr<HloModule> module,
183-
const std::vector<std::vector<Literal*>> arguments, int64_t num_replicas,
184-
int64_t num_partitions, bool run_hlo_passes) {
183+
const std::vector<std::vector<Literal*>> arguments, bool run_hlo_passes) {
184+
int64_t num_replicas = module->config().replica_count();
185+
int64_t num_partitions = module->config().num_partitions();
186+
185187
CHECK(num_replicas > 0 && "expect at least one replica");
186188
CHECK(num_partitions > 0 && "expect at least one partition");
187189
CHECK(num_replicas == arguments.size() &&

xla/tests/collective_ops_e2e_test_base.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase {
7777

7878
absl::StatusOr<ExecutionResult> ExecuteReplicated(
7979
std::unique_ptr<HloModule> module,
80-
std::vector<std::vector<Literal*>> arguments, int64_t num_replicas,
81-
int64_t num_partitions, bool run_hlo_passes = true);
80+
std::vector<std::vector<Literal*>> arguments, bool run_hlo_passes = true);
8281

8382
const se::GpuComputeCapability& Capability() {
8483
return hlo_runner_->backend()

0 commit comments

Comments
 (0)