@@ -334,7 +334,6 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
334
334
std::shared_ptr<PjRtComputationClient::PjRtData>
335
335
PjRtComputationClient::ReplicateShardedData (
336
336
const ComputationClient::DataPtr& handle) {
337
- std::cout << " PjRtComputationClient::ReplicateShardedData" << std::endl;
338
337
if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
339
338
return unsharded_data;
340
339
} else if (auto sharded_data =
@@ -348,9 +347,7 @@ PjRtComputationClient::ReplicateShardedData(
348
347
}
349
348
xla::XlaBuilder builder (" ReplicateShardedData" );
350
349
xla::Shape shape = sharded_data->shape ();
351
- xla::OpSharding sharding = sharded_data->GetSharding ();
352
- builder.SetSharding (sharding);
353
- size_t num_partitions = sharding.tile_assignment_devices ().size ();
350
+ builder.SetSharding (sharded_data->GetSharding ());
354
351
355
352
// perform a simple identity calculation to reassemble the input as
356
353
// replicated output.
@@ -374,7 +371,6 @@ PjRtComputationClient::ReplicateShardedData(
374
371
GetCompilationDevices (device, {}), &shape,
375
372
/* should_wrap_parameter=*/ false ,
376
373
/* is_sharded=*/ true ,
377
- /* computation_num_partitions*/ num_partitions,
378
374
/* allow_spmd_sharding_propagation_to_output=*/ false });
379
375
std::vector<
380
376
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -541,7 +537,6 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
541
537
542
538
std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile (
543
539
std::vector<ComputationClient::CompileInstance> instances) {
544
- std::cout << " in compile" << std::endl;
545
540
auto metrics_fn = CompileMetric;
546
541
if (instances[0 ].eager_mode ) {
547
542
metrics_fn = EagerCompileMetric;
@@ -551,9 +546,7 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
551
546
tsl::profiler::TraceMeLevel::kInfo );
552
547
std::vector<ComputationClient::ComputationPtr> computations;
553
548
554
- std::cout << " instances.size(): " << instances.size () << std::endl;
555
549
for (auto & instance : instances) {
556
- std::cout << " instance devices " << instance.devices << std::endl;
557
550
xla::CompileOptions compile_options;
558
551
if (instance.is_sharded ) {
559
552
// TODO(yeounoh) multi-host, multi-slice configurations
@@ -570,8 +563,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
570
563
if (runtime::sys_util::GetEnvBool (" XLA_USE_LOCAL_SPMD" , false )) {
571
564
num_partitions = GetNumLocalDevices ();
572
565
}
573
- // num_partitions = static_cast<int>(instance.computation_num_partitions);
574
- std::cout << " num_partitions: " << num_partitions << std::endl;
575
566
compile_options.executable_build_options .set_num_partitions (
576
567
num_partitions);
577
568
compile_options.executable_build_options .set_num_replicas (1 );
@@ -668,7 +659,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
668
659
669
660
CreateCompileHandlesCounter ()->AddValue (1 );
670
661
}
671
- std::cout << " finish compile" << std::endl;
672
662
return computations;
673
663
}
674
664
@@ -720,7 +710,6 @@ PjRtComputationClient::ExecuteComputation(
720
710
const ComputationClient::Computation& computation,
721
711
absl::Span<const ComputationClient::DataPtr> arguments,
722
712
const std::string& device, const ExecuteComputationOptions& options) {
723
- std::cout << " in execute" << std::endl;
724
713
// Shared ownership of the timed section ensures that it will only get logged
725
714
// once both `ExecuteComputation` and the async work in `ExecuteSharded` are
726
715
// complete; a copy is held from the lambda that releases it when done.
@@ -788,7 +777,6 @@ PjRtComputationClient::ExecuteComputation(
788
777
CreateDataHandlesCounter ()->AddValue (datas.size ());
789
778
790
779
TF_VLOG (1 ) << " Returning " << datas.size () << " results" ;
791
- std::cout << " finish execute" << std::endl;
792
780
return datas;
793
781
}
794
782
@@ -798,10 +786,6 @@ PjRtComputationClient::ExecuteReplicated(
798
786
absl::Span<const ComputationClient::DataPtr> arguments,
799
787
absl::Span<const std::string> devices,
800
788
const ExecuteReplicatedOptions& options) {
801
- std::cout << " in execute replicated" << std::endl;
802
- for (auto d : devices) {
803
- std::cout << " device: " << d << std::endl;
804
- }
805
789
// Shared ownership of the timed section ensures that it will only get logged
806
790
// once both `ExecuteReplicated` and the async work in `Execute` are
807
791
// complete; a copy is held from the lambda that releases it when done.
@@ -939,7 +923,6 @@ PjRtComputationClient::ExecuteReplicated(
939
923
}
940
924
941
925
TF_VLOG (1 ) << " Returning " << data_handles.size () << " sharded outputs." ;
942
- std::cout << " finish execute replicated" << std::endl;
943
926
return data_handles;
944
927
}
945
928
@@ -1002,17 +985,12 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice(
1002
985
1003
986
void PjRtComputationClient::WaitDeviceOps (
1004
987
absl::Span<const std::string> devices) {
1005
- std::cout << " in wait device ops" << std::endl;
1006
- for (auto d : devices) {
1007
- std::cout << " device: " << d << std::endl;
1008
- }
1009
988
TF_VLOG (3 ) << " Waiting for " << absl::StrJoin (devices, " , " );
1010
989
operation_manager_.WaitForDevices (
1011
990
devices.empty ()
1012
991
? (UseVirtualDevice () ? std::vector<std::string>({spmd_device_str})
1013
992
: GetLocalDevices ())
1014
993
: devices);
1015
- std::cout << " finish wait device ops" << std::endl;
1016
994
}
1017
995
1018
996
std::map<std::string, Metric> PjRtComputationClient::GetMetrics () const {
0 commit comments