
Commit 3865f67

remove debug prints, attempt to derive num partitions from lowering
1 parent e2d157b commit 3865f67

7 files changed: 6 additions, 128 deletions

torch_xla/csrc/lowering_context.cpp (-45)

@@ -93,7 +93,6 @@ LoweringContext::LoweringContext(const std::string& name,
                                  torch::lazy::BackendDevice device)
     : torch::lazy::LoweringContext(name, device),
       builder_(name),
-      num_computation_partitions_(1),
       stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {}
 
 LoweringContext::LoweringContext(
@@ -102,7 +101,6 @@ LoweringContext::LoweringContext(
     torch::lazy::Util::EmissionMap emit_status)
     : torch::lazy::LoweringContext(name, device, {}, emit_status),
       builder_(name),
-      num_computation_partitions_(1),
       stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {
   for (auto node : post_order) {
     LowerNode(node);
@@ -133,7 +131,6 @@ xla::XlaOp LoweringContext::GetParameter(
       xla::OpSharding sharding = data->GetSharding();
       xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
       param = xla::Parameter(builder(), param_index, shape, param_name);
-      UpdateNumPartitions(param);
     } else {
       param = xla::Parameter(builder(), param_index, shape, param_name);
     }
@@ -257,28 +254,6 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {
         mutable_dims->Set(dim, kUnboundedSize);
       }
     }
-    std::for_each(result_ops.begin(), result_ops.end(),
-                  [this](xla::XlaOp xla_op) {
-                    UpdateNumPartitions(xla_op);  // Calling the member function
-                  });
-    // for (auto xla_op : result_ops) {
-    //   UpdateNumPartitions(xla_op);
-    //   // std::optional<OpSharding> op_sharding =
-    //   //     ConsumeValue(builder()->GetOpSharding(xla_op));
-    //   // if (op_sharding.has_value()) {
-    //   //   size_t curr_num_partitions =
-    //   //       op_sharding.value().tile_assignment_devices().size();
-    //   //   if (num_computation_partitions_ != 1) {
-    //   //     XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
-    //   <<
-    //   // "Number of partitions must be the same for all ops in a HLO
-    //   graph.";
-    //   //     continue;
-    //   //   }
-    //   //   num_computation_partitions_ =
-    //   op_sharding.value().tile_assignment_devices().size();
-    //   // }
-    // }
   } catch (const std::exception& ex) {
     ReportBuilderError(node, ex.what());
   }
@@ -349,24 +324,4 @@ torch::lazy::ComputationPtr LoweringContext::Build() {
       builder_.name(), std::move(xla_computation), device_);
 }
 
-void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) {
-  std::optional<xla::OpSharding> op_sharding =
-      ConsumeValue(builder()->GetOpSharding(op));
-  if (op_sharding.has_value()) {
-    size_t curr_num_partitions =
-        op_sharding.value().tile_assignment_devices().size();
-    if (curr_num_partitions == 0) {
-      return;
-    }
-    if (num_computation_partitions_ != 1) {
-      XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
-          << "Number of partitions must be the same for all ops in a HLO "
-             "graph.";
-      return;
-    }
-    std::cout << "curr_num_partitions: " << curr_num_partitions << std::endl;
-    num_computation_partitions_ = curr_num_partitions;
-  }
-}
-
 }  // namespace torch_xla

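Note: the UpdateNumPartitions helper removed above derived the computation's partition count by querying each op's sharding on the XlaBuilder; with it gone, the partition count is chosen in PjRtComputationClient::Compile (e.g. GetNumLocalDevices() when XLA_USE_LOCAL_SPMD is set). A minimal sketch of that derivation is kept below for reference; the free-function form and name are illustrative, not code from this commit.

// Sketch: read one op's sharding and derive a partition count, if present.
// Mirrors the removed LoweringContext::UpdateNumPartitions; assumes the
// usual XLA client headers (xla::XlaBuilder, xla::XlaOp, xla::OpSharding).
std::optional<size_t> GetNumPartitionsFromOp(xla::XlaBuilder* builder,
                                             const xla::XlaOp& op) {
  auto sharding = builder->GetOpSharding(op);  // StatusOr<optional<OpSharding>>
  if (!sharding.ok() || !sharding->has_value()) {
    return std::nullopt;  // No sharding recorded for this op.
  }
  size_t num_partitions = sharding->value().tile_assignment_devices().size();
  if (num_partitions == 0) {
    return std::nullopt;  // Sharding carries no explicit tile assignment.
  }
  return num_partitions;
}
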
torch_xla/csrc/lowering_context.h (-10)

@@ -113,18 +113,10 @@ class LoweringContext : public torch::lazy::LoweringContext {
     return emitted_outputs_;
   }
 
-  size_t GetComputationNumPartitions() const {
-    return num_computation_partitions_;
-  }
-
   // Return stack frame id
   int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
                                 int64_t parent_id);
 
- protected:
-  // Update the number of partitions from a XlaOp.
-  void UpdateNumPartitions(const xla::XlaOp& op);
-
  private:
   struct Parameter {
     xla::XlaOp param;
@@ -141,8 +133,6 @@ class LoweringContext : public torch::lazy::LoweringContext {
   std::vector<xla::XlaOp> root_tuple_;
   OutputMap<xla::XlaOp> emitted_outputs_;
   std::string name_;
-  // Number of partitions of the lowered XLA computation.
-  size_t num_computation_partitions_;
 
   std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
 };  // namespace torch_xla

torch_xla/csrc/runtime/computation_client.h (-3)

@@ -225,7 +225,6 @@ class ComputationClient {
         xla::XlaComputation computation, std::string compilation_device,
         std::vector<std::string> devices, const xla::Shape* output_shape,
         bool parameter_is_tupled_arguments = false, bool is_sharded = false,
-        size_t computation_num_partitions = 1,
         bool allow_spmd_sharding_propagation_to_output = true,
         bool use_auto_spmd_partitioning = false,
         std::vector<int64_t> auto_spmd_mesh_shape = {},
@@ -236,7 +235,6 @@
           output_shape(output_shape),
           parameter_is_tupled_arguments(parameter_is_tupled_arguments),
           is_sharded(is_sharded),
-          computation_num_partitions(computation_num_partitions),
           allow_spmd_sharding_propagation_to_output(
               allow_spmd_sharding_propagation_to_output),
           use_auto_spmd_partitioning(use_auto_spmd_partitioning),
@@ -250,7 +248,6 @@
     const xla::Shape* output_shape = nullptr;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
-    size_t computation_num_partitions = 1;
     bool allow_spmd_sharding_propagation_to_output;
     bool use_auto_spmd_partitioning;
     std::vector<int64_t> auto_spmd_mesh_shape;

torch_xla/csrc/runtime/pjrt_computation_client.cc (+1, -23)

@@ -334,7 +334,6 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
 std::shared_ptr<PjRtComputationClient::PjRtData>
 PjRtComputationClient::ReplicateShardedData(
     const ComputationClient::DataPtr& handle) {
-  std::cout << "PjRtComputationClient::ReplicateShardedData" << std::endl;
   if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
     return unsharded_data;
   } else if (auto sharded_data =
@@ -348,9 +347,7 @@ PjRtComputationClient::ReplicateShardedData(
   }
   xla::XlaBuilder builder("ReplicateShardedData");
   xla::Shape shape = sharded_data->shape();
-  xla::OpSharding sharding = sharded_data->GetSharding();
-  builder.SetSharding(sharding);
-  size_t num_partitions = sharding.tile_assignment_devices().size();
+  builder.SetSharding(sharded_data->GetSharding());
 
   // perform a simple identity calculation to reassemble the input as
   // replicated output.
@@ -374,7 +371,6 @@ PjRtComputationClient::ReplicateShardedData(
        GetCompilationDevices(device, {}), &shape,
        /*should_wrap_parameter=*/false,
        /*is_sharded=*/true,
-       /*computation_num_partitions*/ num_partitions,
        /*allow_spmd_sharding_propagation_to_output=*/false});
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -541,7 +537,6 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
 
 std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     std::vector<ComputationClient::CompileInstance> instances) {
-  std::cout << "in compile" << std::endl;
   auto metrics_fn = CompileMetric;
   if (instances[0].eager_mode) {
     metrics_fn = EagerCompileMetric;
@@ -551,9 +546,7 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       tsl::profiler::TraceMeLevel::kInfo);
   std::vector<ComputationClient::ComputationPtr> computations;
 
-  std::cout << "instances.size(): " << instances.size() << std::endl;
   for (auto& instance : instances) {
-    std::cout << "instance devices " << instance.devices << std::endl;
     xla::CompileOptions compile_options;
     if (instance.is_sharded) {
       // TODO(yeounoh) multi-host, multi-slice configurations
@@ -570,8 +563,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
         num_partitions = GetNumLocalDevices();
       }
-      // num_partitions = static_cast<int>(instance.computation_num_partitions);
-      std::cout << "num_partitions: " << num_partitions << std::endl;
       compile_options.executable_build_options.set_num_partitions(
           num_partitions);
       compile_options.executable_build_options.set_num_replicas(1);
@@ -668,7 +659,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 
     CreateCompileHandlesCounter()->AddValue(1);
   }
-  std::cout << "finish compile" << std::endl;
   return computations;
 }
 
@@ -720,7 +710,6 @@ PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
     const std::string& device, const ExecuteComputationOptions& options) {
-  std::cout << "in execute" << std::endl;
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
   // complete; a copy is held from the lambda that releases it when done.
@@ -788,7 +777,6 @@ PjRtComputationClient::ExecuteComputation(
   CreateDataHandlesCounter()->AddValue(datas.size());
 
   TF_VLOG(1) << "Returning " << datas.size() << " results";
-  std::cout << "finish execute" << std::endl;
   return datas;
 }
 
@@ -798,10 +786,6 @@ PjRtComputationClient::ExecuteReplicated(
     absl::Span<const ComputationClient::DataPtr> arguments,
     absl::Span<const std::string> devices,
     const ExecuteReplicatedOptions& options) {
-  std::cout << "in execute replicated" << std::endl;
-  for (auto d : devices) {
-    std::cout << "device: " << d << std::endl;
-  }
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteReplicated` and the async work in `Execute` are
   // complete; a copy is held from the lambda that releases it when done.
@@ -939,7 +923,6 @@ PjRtComputationClient::ExecuteReplicated(
   }
 
   TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
-  std::cout << "finish execute replicated" << std::endl;
   return data_handles;
 }
 
@@ -1002,17 +985,12 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice(
 
 void PjRtComputationClient::WaitDeviceOps(
     absl::Span<const std::string> devices) {
-  std::cout << "in wait device ops" << std::endl;
-  for (auto d : devices) {
-    std::cout << "device: " << d << std::endl;
-  }
   TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", ");
   operation_manager_.WaitForDevices(
       devices.empty()
          ? (UseVirtualDevice() ? std::vector<std::string>({spmd_device_str})
                                : GetLocalDevices())
          : devices);
-  std::cout << "finish wait device ops" << std::endl;
 }
 
 std::map<std::string, Metric> PjRtComputationClient::GetMetrics() const {

torch_xla/csrc/xla_graph_executor.cpp (+2, -12)

@@ -1391,16 +1391,12 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   // Always execute sharded when running in SPMD mode
   bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
   // Annotate HLO sharding selectively in the compuation.
-  bool is_sharded_2 = ShardingUtil::SetHloSharding(&lowering_ctx);
-
-  std::cout << "is_sharded_2: " << is_sharded_2 << std::endl;
+  ShardingUtil::SetHloSharding(&lowering_ctx);
 
   SetBufferDonors(&lowering_ctx, buffer_donor_indices);
 
   xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
   xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
-  size_t computation_num_partitions =
-      lowering_ctx.GetComputationNumPartitions();
 
   // TODO(yeounoh) enable wrapping with auto-sharding.
   bool should_wrap_parameter =
@@ -1426,15 +1422,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
 
   std::vector<runtime::ComputationClient::CompileInstance> instances;
-  std::cout << "computation_num_partitions: " << computation_num_partitions
-            << std::endl;
   instances.emplace_back(std::move(computation), coll.device.toString(),
                          runtime::GetComputationClient()->GetCompilationDevices(
                              coll.device.toString(), devices),
-                         &shape, should_wrap_parameter, is_sharded,
-                         computation_num_partitions);
+                         &shape, should_wrap_parameter, is_sharded);
   instances.front().eager_mode = UseEagerMode();
-  instances.front().computation_num_partitions = computation_num_partitions;
   if (use_autosharding) {
     TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
     TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode.";
@@ -1463,8 +1455,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   TF_VLOG(3) << "Compiling IR graph hash "
              << torch::lazy::HashToString(coll.hash) << " on device "
             << coll.device << " ...";
-  std::cout << "check instance num partitions"
-            << instances.front().computation_num_partitions << std::endl;
   std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
       computations =
           runtime::GetComputationClient()->Compile(std::move(instances));

torch_xla/csrc/xla_sharding_util.cpp (+3, -31)

@@ -192,9 +192,6 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
         XlaBuilderFriend::GetInstruction(elem.second);
     const std::shared_ptr<xla::OpSharding> sharding =
         xla_node->GetSharding(elem.first.index);
-    if (sharding != nullptr) {
-      std::cout << "check opsharding " << sharding->DebugString() << std::endl;
-    }
     if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) {
       *instruction->mutable_sharding() = *sharding;
       is_sharded = true;
@@ -375,33 +372,15 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
       shard_indices[i] = std::make_pair(global_ordinal, indices);
     }
   } else if (sharding.type() == xla::OpSharding::OTHER) {
-    std::vector<int64_t> tile_assignment_devices(
-        sharding.tile_assignment_devices().begin(),
-        sharding.tile_assignment_devices().end());
-    size_t num_local_devices =
-        runtime::GetComputationClient()->GetNumLocalDevices();
-    size_t num_global_devices =
-        runtime::GetComputationClient()->GetNumGlobalDevices();
-    // XLA_CHECK(tile_assignment_devices.size() == 0 ||
-    //           tile_assignment_devices.size() == num_global_devices ||
-    //           tile_assignment_devices.size() == num_local_devices)
-    //     << "Number of tile_assignment_devices must be the number of global "
-    //        "devices or local devices, or 0, got unexpected size of "
-    //     << tile_assignment_devices.size();
     size_t num_tiles =
         std::accumulate(sharding.tile_assignment_dimensions().begin(),
                         sharding.tile_assignment_dimensions().end(), 1,
                         [](int a, int b) { return a * b; });
-    std::cout << "Num local devices " << num_local_devices << std::endl;
-    std::cout << "Num tile assignment size " << tile_assignment_devices.size()
-              << std::endl;
     std::unordered_map<int, int> device_index =
         build_index_map(devices, num_tiles);
-    std::cout << "Check device_index " << std::endl;
-    for (const auto& pair : device_index) {
-      std::cout << "Key: " << pair.first << ", Value: " << pair.second
-                << std::endl;
-    }
+    std::vector<int64_t> tile_assignment_devices(
+        sharding.tile_assignment_devices().begin(),
+        sharding.tile_assignment_devices().end());
     if (!sharding.iota_reshape_dims().empty()) {
       auto tileAssignment = xla::TileAssignment(
           sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(),
@@ -411,10 +390,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
     }
     for (size_t i = 0; i < tile_assignment_devices.size(); i++) {
       int64_t core = tile_assignment_devices[i];
-      std::cout << "Check core " << core << std::endl;
      if (device_index.find(core) == device_index.end()) {
-        std::cout << "current core " << core << " is not in device_index"
-                  << std::endl;
         // Skip any shards whose device is not part of the `devices` list.
         continue;
       }
@@ -464,8 +440,6 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
 std::vector<at::Tensor> ShardingUtil::ShardTensor(
     const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
     const std::vector<std::string>& devices, bool padded) {
-  std::cout << "ShardingUtil::ShardTensor check devices " << devices
-            << std::endl;
   xla::OpSharding sharding;
   bool minibatch = false;
   if (shardings != nullptr) {
@@ -496,8 +470,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
                    std::back_inserter(shard_indices),
                    [](auto& pair) { return pair.second; });
   }
-  std::cout << "ShardingUtil::ShardTensor check shard_indices: "
-            << shard_indices << std::endl;
 
   for (size_t i = 0; i < shard_indices.size(); i++) {
     at::Tensor shard = tensor.index(

torch_xla/distributed/spmd/xla_sharding.py (-4)

@@ -130,10 +130,6 @@ def get_op_sharding(self,
 
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
-    print(f"check tile_assignment: {tile_assignment}")
-    print(f"check group_assignment: {group_assignment}")
-    print(f"check replication_groups: {replication_groups}")
-    print(f"check sharding_type: {sharding_type}")
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
