From 2d2f08ab95452fc0e36e12339db163652aed5050 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Sun, 9 Mar 2025 08:01:01 +0000
Subject: [PATCH 01/16] update mesh/xla_sharding python api for local spmd

---
 torch_xla/distributed/spmd/xla_sharding.py | 18 ++++++++++++++----
 torch_xla/runtime.py                       |  2 +-
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index a1cd9540fd1..de2714ad249 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -68,7 +68,7 @@ def __init__(self,
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    assert all(d < self.size() for d in device_ids)
+    # assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -127,6 +127,10 @@ def get_op_sharding(self,
 
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
+    print(f"check tile_assignment: {tile_assignment}")
+    print(f"check group_assignment: {group_assignment}")
+    print(f"check replication_groups: {replication_groups}")
+    print(f"check sharding_type: {sharding_type}")
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 
@@ -377,6 +381,11 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
   return sharding_type
 
 
+def _normalize_logical_mesh(device_mesh: np.ndarray) -> np.ndarray:
+  device_id_min = np.min(device_mesh)
+  return device_mesh.copy() - device_id_min
+
+
 def _get_tile_assignment(
     mesh: Mesh, partition_spec: Tuple[Union[Tuple[int], int,
                                             None]]) -> np.ndarray:
@@ -393,8 +402,8 @@ def _get_tile_assignment(
   tiled_dims = [x for x in partition_spec if x is not None]
   permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
   missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
-  tile_assignment = mesh.get_logical_mesh().transpose(permutation +
-                                                      missing_axes)
+  tile_assignment = _normalize_logical_mesh(
+      mesh.get_logical_mesh()).transpose(permutation + missing_axes)
 
   # For any tuples in the partition_spec, the grouped axes will be adjacent
   # after the permutation. Combine these dimensions into a single axis.
@@ -548,8 +557,9 @@ def mark_sharding(
       >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
   num_devices = xr.global_runtime_device_count()
+  num_local_devices = xr.addressable_runtime_device_count()
   assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
+  assert mesh.size() == num_devices or mesh.size() == num_local_devices, \
     f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 1946ae05a52..a17b1c57e3b 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -212,7 +212,7 @@ def global_runtime_device_attributes() -> List[Dict[str, object]]:
 @functools.lru_cache()
 def global_runtime_device_count() -> int:
   """Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD."""
-  return len(torch_xla._XLAC._xla_get_all_runtime_devices())
+  return torch_xla._XLAC._xla_num_global_devices()
 
 
 def addressable_runtime_device_count() -> int:

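Patch 01 loosens the Python-side checks so that a mesh can cover only this host's addressable devices instead of all global devices. A minimal usage sketch of that intent (not taken from the patch; it assumes a multi-host setup where each host owns a contiguous block of device ids):

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()

    num_local = xr.addressable_runtime_device_count()
    # Assumed layout: host k owns device ids [k * num_local, (k + 1) * num_local).
    first_id = xr.process_index() * num_local
    device_ids = np.arange(first_id, first_id + num_local)

    # A host-local 1D mesh; _get_tile_assignment() normalizes these ids so the
    # HLO tile assignment starts from logical device 0.
    mesh = xs.Mesh(device_ids, (num_local,), ('data',))

    t = torch.randn(16, 128).to(xm.xla_device())
    # Accepted after this patch because mesh.size() matches the local device count.
    xs.mark_sharding(t, mesh, ('data', None))
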
From c4aa8542eb00fda59a9281043d5ec2dce7517820 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Sun, 9 Mar 2025 08:02:57 +0000
Subject: [PATCH 02/16] make local spmd work

---
 torch_xla/csrc/init_python_bindings.cpp       |  7 ++-
 torch_xla/csrc/lowering_context.cpp           | 42 +++++++++++++++++
 torch_xla/csrc/lowering_context.h             | 10 ++++
 torch_xla/csrc/runtime/computation_client.h   |  7 ++-
 .../csrc/runtime/ifrt_computation_client.cc   |  6 ++-
 .../csrc/runtime/ifrt_computation_client.h    |  4 +-
 .../csrc/runtime/pjrt_computation_client.cc   | 47 ++++++++++++++++---
 .../csrc/runtime/pjrt_computation_client.h    |  4 +-
 torch_xla/csrc/tensor_impl.cpp                |  2 +-
 torch_xla/csrc/xla_graph_executor.cpp         | 20 ++++++--
 torch_xla/csrc/xla_sharding_util.cpp          | 34 ++++++++++++--
 11 files changed, 161 insertions(+), 22 deletions(-)

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 98012ea2d35..020979a4880 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1482,7 +1482,7 @@ void InitXlaModuleBindings(py::module m) {
     if (UseVirtualDevice()) {
       return 1;
     } else {
-      return runtime::GetComputationClient()->GetNumDevices();
+      return runtime::GetComputationClient()->GetNumLocalDevices();
     }
   });
   m.def("_xla_get_all_devices", []() {
@@ -1500,13 +1500,16 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_get_runtime_devices",
         []() { return runtime::GetComputationClient()->GetLocalDevices(); });
   m.def("_xla_num_runtime_devices", []() -> int64_t {
-    return runtime::GetComputationClient()->GetNumDevices();
+    return runtime::GetComputationClient()->GetNumLocalDevices();
   });
   m.def("_xla_get_all_runtime_devices", []() {
     std::vector<std::string> all_devices =
         runtime::GetComputationClient()->GetAllDevices();
     return all_devices;
   });
+  m.def("_xla_num_global_devices", []() -> int64_t {
+    return runtime::GetComputationClient()->GetNumGlobalDevices();
+  });
   m.def(
       "_xla_real_devices",
       [](const std::optional<std::vector<std::string>> devices) {
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index 6c2906dc724..a004be88c54 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -93,6 +93,7 @@ LoweringContext::LoweringContext(const std::string& name,
                                  torch::lazy::BackendDevice device)
     : torch::lazy::LoweringContext(name, device),
       builder_(name),
+      num_computation_partitions_(1),
       stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {}
 
 LoweringContext::LoweringContext(
@@ -101,6 +102,7 @@ LoweringContext::LoweringContext(
     torch::lazy::Util::EmissionMap emit_status)
     : torch::lazy::LoweringContext(name, device, {}, emit_status),
       builder_(name),
+      num_computation_partitions_(1),
       stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {
   for (auto node : post_order) {
     LowerNode(node);
@@ -131,6 +133,7 @@ xla::XlaOp LoweringContext::GetParameter(
       xla::OpSharding sharding = data->GetSharding();
       xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
       param = xla::Parameter(builder(), param_index, shape, param_name);
+      UpdateNumPartitions(param);
     } else {
       param = xla::Parameter(builder(), param_index, shape, param_name);
     }
@@ -254,6 +257,28 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {
         mutable_dims->Set(dim, kUnboundedSize);
       }
     }
+    std::for_each(result_ops.begin(), result_ops.end(),
+                  [this](xla::XlaOp xla_op) {
+                    UpdateNumPartitions(xla_op);  // Calling the member function
+                  });
+    // for (auto xla_op : result_ops) {
+    //   UpdateNumPartitions(xla_op);
+    //   // std::optional<OpSharding> op_sharding =
+    //   //   ConsumeValue(builder()->GetOpSharding(xla_op));
+    //   // if (op_sharding.has_value()) {
+    //   //   size_t curr_num_partitions =
+    //   //     op_sharding.value().tile_assignment_devices().size();
+    //   //   if (num_computation_partitions_ != 1) {
+    //   //     XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
+    //   <<
+    //   //       "Number of partitions must be the same for all ops in a HLO
+    //   graph.";
+    //   //     continue;
+    //   //   }
+    //   //   num_computation_partitions_ =
+    //   op_sharding.value().tile_assignment_devices().size();
+    //   // }
+    // }
   } catch (const std::exception& ex) {
     ReportBuilderError(node, ex.what());
   }
@@ -324,4 +349,21 @@ torch::lazy::ComputationPtr LoweringContext::Build() {
       builder_.name(), std::move(xla_computation), device_);
 }
 
+void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) {
+  std::optional<xla::OpSharding> op_sharding =
+      ConsumeValue(builder()->GetOpSharding(op));
+  if (op_sharding.has_value()) {
+    size_t curr_num_partitions =
+        op_sharding.value().tile_assignment_devices().size();
+    if (num_computation_partitions_ != 1) {
+      XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
+          << "Number of partitions must be the same for all ops in a HLO "
+             "graph.";
+      return;
+    }
+    std::cout << "curr_num_partitions: " << curr_num_partitions << std::endl;
+    num_computation_partitions_ = curr_num_partitions;
+  }
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h
index cb4f0bc2d2f..fdaabb2b14d 100644
--- a/torch_xla/csrc/lowering_context.h
+++ b/torch_xla/csrc/lowering_context.h
@@ -113,10 +113,18 @@ class LoweringContext : public torch::lazy::LoweringContext {
     return emitted_outputs_;
   }
 
+  size_t GetComputationNumPartitions() const {
+    return num_computation_partitions_;
+  }
+
   // Return stack frame id
   int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
                                 int64_t parent_id);
 
+ protected:
+  // Update the number of partitions from a XlaOp.
+  void UpdateNumPartitions(const xla::XlaOp& op);
+
  private:
   struct Parameter {
     xla::XlaOp param;
@@ -133,6 +141,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
   std::vector<xla::XlaOp> root_tuple_;
   OutputMap<xla::XlaOp> emitted_outputs_;
   std::string name_;
+  // Number of partitions of the lowered XLA computation.
+  size_t num_computation_partitions_;
 
   std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
 };  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index b192d8d2e14..339d2a4f52c 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -225,6 +225,7 @@ class ComputationClient {
         xla::XlaComputation computation, std::string compilation_device,
         std::vector<std::string> devices, const xla::Shape* output_shape,
         bool parameter_is_tupled_arguments = false, bool is_sharded = false,
+        size_t computation_num_partitions = 1,
         bool allow_spmd_sharding_propagation_to_output = true,
         bool use_auto_spmd_partitioning = false,
         std::vector<int64_t> auto_spmd_mesh_shape = {},
@@ -235,6 +236,7 @@ class ComputationClient {
           output_shape(output_shape),
           parameter_is_tupled_arguments(parameter_is_tupled_arguments),
           is_sharded(is_sharded),
+          computation_num_partitions(computation_num_partitions),
           allow_spmd_sharding_propagation_to_output(
               allow_spmd_sharding_propagation_to_output),
           use_auto_spmd_partitioning(use_auto_spmd_partitioning),
@@ -248,6 +250,7 @@ class ComputationClient {
     const xla::Shape* output_shape = nullptr;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
+    size_t computation_num_partitions;
     bool allow_spmd_sharding_propagation_to_output;
     bool use_auto_spmd_partitioning;
     std::vector<int64_t> auto_spmd_mesh_shape;
@@ -374,7 +377,9 @@ class ComputationClient {
 
   virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;
 
-  virtual size_t GetNumDevices() const = 0;
+  virtual size_t GetNumLocalDevices() const = 0;
+
+  virtual size_t GetNumGlobalDevices() const = 0;
 
   virtual std::vector<std::string> GetLocalDevices() const = 0;
 
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index a197aec460e..11aaa1a0b8d 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated(
   return data_handles;
 }
 
-size_t IfrtComputationClient::GetNumDevices() const {
+size_t IfrtComputationClient::GetNumLocalDevices() const {
   return client_->addressable_device_count();
 }
 
+size_t IfrtComputationClient::GetNumGlobalDevices() const {
+  return client_->device_count();
+}
+
 std::string IfrtComputationClient::GetDefaultDevice() const {
   return IfrtDeviceToString(client_->addressable_devices()[0]);
 }
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 73b8e21c9f0..26135f65ab5 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -79,7 +79,9 @@ class IfrtComputationClient : public ComputationClient {
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) override;
 
-  size_t GetNumDevices() const override;
+  size_t GetNumLocalDevices() const override;
+
+  size_t GetNumGlobalDevices() const override;
 
   std::string GetDefaultDevice() const override;
 
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 749419f66cd..6bf6217c036 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -334,6 +334,7 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
 std::shared_ptr<PjRtComputationClient::PjRtData>
 PjRtComputationClient::ReplicateShardedData(
     const ComputationClient::DataPtr& handle) {
+  std::cout << "PjRtComputationClient::ReplicateShardedData" << std::endl;
   if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
     return unsharded_data;
   } else if (auto sharded_data =
@@ -347,7 +348,9 @@ PjRtComputationClient::ReplicateShardedData(
     }
     xla::XlaBuilder builder("ReplicateShardedData");
     xla::Shape shape = sharded_data->shape();
-    builder.SetSharding(sharded_data->GetSharding());
+    xla::OpSharding sharding = sharded_data->GetSharding();
+    builder.SetSharding(sharding);
+    size_t num_partitions = sharding.tile_assignment_devices().size();
 
     // perform a simple identity calculation to reassemble the input as
     // replicated output.
@@ -371,6 +374,7 @@ PjRtComputationClient::ReplicateShardedData(
                          GetCompilationDevices(device, {}), &shape,
                          /*should_wrap_parameter=*/false,
                          /*is_sharded=*/true,
+                         /*computation_num_partitions*/ num_partitions,
                          /*allow_spmd_sharding_propagation_to_output=*/false});
     std::vector<
         std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -537,6 +541,7 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
 
 std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     std::vector<ComputationClient::CompileInstance> instances) {
+  std::cout << "in compile" << std::endl;
   auto metrics_fn = CompileMetric;
   if (instances[0].eager_mode) {
     metrics_fn = EagerCompileMetric;
@@ -546,7 +551,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
                                   tsl::profiler::TraceMeLevel::kInfo);
   std::vector<ComputationClient::ComputationPtr> computations;
 
+  std::cout << "instances.size(): " << instances.size() << std::endl;
   for (auto& instance : instances) {
+    std::cout << "instance devices " << instance.devices << std::endl;
     xla::CompileOptions compile_options;
     if (instance.is_sharded) {
       // TODO(yeounoh) multi-host, multi-slice configurations
@@ -560,6 +567,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
               {instance.allow_spmd_sharding_propagation_to_output});
 
       int num_partitions = client_->device_count();
+      // num_partitions = 4;
+      num_partitions = static_cast<int>(instance.computation_num_partitions);
+      std::cout << "num_partitions: " << num_partitions << std::endl;
       compile_options.executable_build_options.set_num_partitions(
           num_partitions);
       compile_options.executable_build_options.set_num_replicas(1);
@@ -589,11 +599,20 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       }
 
       // TODO(244391366) verify this is correct for the collectives ops
-      xla::DeviceAssignment device_assignment(1, client_->device_count());
+      // xla::DeviceAssignment device_assignment(1, client_->device_count());
+      xla::DeviceAssignment device_assignment(1, num_partitions);
+      std::cout << "check client_->device_count(): " << client_->device_count()
+                << std::endl;
       // DeviceAssignment values must be the PjRtDevice ID, so we need to
       // unwind the global ordinal mapping.
-      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
-        device_assignment(0, global_ordinal) = device_id;
+      // for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+      //   std::cout << "device_id: " << device_id
+      //             << ", global_ordinal: " << global_ordinal << std::endl;
+      //   device_assignment(0, global_ordinal) = device_id;
+      // }
+      auto local_pjrt_devices = client_->addressable_devices();
+      for (int i = 0; i < local_pjrt_devices.size(); ++i) {
+        device_assignment(0, i) = local_pjrt_devices[i]->id();
       }
       compile_options.executable_build_options.set_device_assignment(
           device_assignment);
@@ -649,7 +668,7 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 
     CreateCompileHandlesCounter()->AddValue(1);
   }
-
+  std::cout << "finish compile" << std::endl;
   return computations;
 }
 
@@ -701,6 +720,7 @@ PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
     const std::string& device, const ExecuteComputationOptions& options) {
+  std::cout << "in execute" << std::endl;
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
   // complete; a copy is held from the lambda that releases it when done.
@@ -768,6 +788,7 @@ PjRtComputationClient::ExecuteComputation(
   CreateDataHandlesCounter()->AddValue(datas.size());
 
   TF_VLOG(1) << "Returning " << datas.size() << " results";
+  std::cout << "finish execute" << std::endl;
   return datas;
 }
 
@@ -777,6 +798,10 @@ PjRtComputationClient::ExecuteReplicated(
     absl::Span<const ComputationClient::DataPtr> arguments,
     absl::Span<const std::string> devices,
     const ExecuteReplicatedOptions& options) {
+  std::cout << "in execute replicated" << std::endl;
+  for (auto d : devices) {
+    std::cout << "device: " << d << std::endl;
+  }
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteReplicated` and the async work in `Execute` are
   // complete; a copy is held from the lambda that releases it when done.
@@ -914,13 +939,18 @@ PjRtComputationClient::ExecuteReplicated(
   }
 
   TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
+  std::cout << "finish execute replicated" << std::endl;
   return data_handles;
 }
 
-size_t PjRtComputationClient::GetNumDevices() const {
+size_t PjRtComputationClient::GetNumLocalDevices() const {
   return client_->addressable_device_count();
 }
 
+size_t PjRtComputationClient::GetNumGlobalDevices() const {
+  return client_->device_count();
+}
+
 std::string PjRtComputationClient::GetDefaultDevice() const {
   return PjRtDeviceToString(client_->addressable_devices()[0]);
 }
@@ -972,12 +1002,17 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice(
 
 void PjRtComputationClient::WaitDeviceOps(
     absl::Span<const std::string> devices) {
+  std::cout << "in wait device ops" << std::endl;
+  for (auto d : devices) {
+    std::cout << "device: " << d << std::endl;
+  }
   TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", ");
   operation_manager_.WaitForDevices(
       devices.empty()
           ? (UseVirtualDevice() ? std::vector<std::string>({spmd_device_str})
                                 : GetLocalDevices())
           : devices);
+  std::cout << "finish wait device ops" << std::endl;
 }
 
 std::map<std::string, Metric> PjRtComputationClient::GetMetrics() const {
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 9791f32381b..090ff952fdf 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -86,7 +86,9 @@ class PjRtComputationClient : public ComputationClient {
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) override;
 
-  size_t GetNumDevices() const override;
+  size_t GetNumLocalDevices() const override;
+
+  size_t GetNumGlobalDevices() const override;
 
   std::string GetDefaultDevice() const override;
 
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 4e69127ff81..fcf793ff5bc 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
       return 0;
     }
 
-    return client->GetNumDevices();
+    return client->GetNumLocalDevices();
   }
 };
 
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 0b8c5489798..514266518dc 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1391,12 +1391,16 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   // Always execute sharded when running in SPMD mode
   bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
   // Annotate HLO sharding selectively in the compuation.
-  ShardingUtil::SetHloSharding(&lowering_ctx);
+  bool is_sharded_2 = ShardingUtil::SetHloSharding(&lowering_ctx);
+
+  std::cout << "is_sharded_2: " << is_sharded_2 << std::endl;
 
   SetBufferDonors(&lowering_ctx, buffer_donor_indices);
 
   xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
   xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
+  size_t computation_num_partitions =
+      lowering_ctx.GetComputationNumPartitions();
 
   // TODO(yeounoh) enable wrapping with auto-sharding.
   bool should_wrap_parameter =
@@ -1422,11 +1426,15 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
 
   std::vector<runtime::ComputationClient::CompileInstance> instances;
-  instances.push_back({std::move(computation), coll.device.toString(),
-                       runtime::GetComputationClient()->GetCompilationDevices(
-                           coll.device.toString(), devices),
-                       &shape, should_wrap_parameter, is_sharded});
+  std::cout << "computation_num_partitions: " << computation_num_partitions
+            << std::endl;
+  instances.emplace_back(std::move(computation), coll.device.toString(),
+                         runtime::GetComputationClient()->GetCompilationDevices(
+                             coll.device.toString(), devices),
+                         &shape, should_wrap_parameter, is_sharded,
+                         computation_num_partitions);
   instances.front().eager_mode = UseEagerMode();
+  instances.front().computation_num_partitions = computation_num_partitions;
   if (use_autosharding) {
     TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
     TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode.";
@@ -1455,6 +1463,8 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   TF_VLOG(3) << "Compiling IR graph hash "
              << torch::lazy::HashToString(coll.hash) << " on device "
              << coll.device << " ...";
+  std::cout << "check instance num partitions"
+            << instances.front().computation_num_partitions << std::endl;
   std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
       computations =
           runtime::GetComputationClient()->Compile(std::move(instances));
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index d58144d6844..b2938f81dbe 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -85,10 +85,11 @@ std::vector<int64_t> TileAssignmentDimensions(
 // order of the output corresponds to the order of the `devices`, which can be
 // arbitrarily set by the caller.
 std::unordered_map<int, int> build_index_map(
-    const std::vector<std::string>& devices) {
+    const std::vector<std::string>& devices, size_t num_mesh_devices) {
   std::unordered_map<int, int> device_index;
   for (int i = 0; i < devices.size(); ++i) {
-    int global_ordinal = ParseDeviceString(devices[i]).ordinal();
+    int global_ordinal =
+        ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;
   }
   return device_index;
@@ -191,6 +192,9 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
         XlaBuilderFriend::GetInstruction(elem.second);
     const std::shared_ptr<xla::OpSharding> sharding =
         xla_node->GetSharding(elem.first.index);
+    if (sharding != nullptr) {
+      std::cout << "check opsharding " << sharding->DebugString() << std::endl;
+    }
     if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) {
       *instruction->mutable_sharding() = *sharding;
       is_sharded = true;
@@ -371,10 +375,25 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
       shard_indices[i] = std::make_pair(global_ordinal, indices);
     }
   } else if (sharding.type() == xla::OpSharding::OTHER) {
-    auto device_index = build_index_map(devices);
     std::vector<int64_t> tile_assignment_devices(
         sharding.tile_assignment_devices().begin(),
         sharding.tile_assignment_devices().end());
+    size_t num_local_devices =
+        runtime::GetComputationClient()->GetNumLocalDevices();
+    size_t num_global_devices =
+        runtime::GetComputationClient()->GetNumGlobalDevices();
+    XLA_CHECK(tile_assignment_devices.size() == num_global_devices ||
+              tile_assignment_devices.size() == num_local_devices)
+        << "Number of tile_assignment_devices must be the number of global "
+           "devices or local devices";
+    std::cout << "Num local devices " << num_local_devices << std::endl;
+    std::unordered_map<int, int> device_index =
+        build_index_map(devices, tile_assignment_devices.size());
+    std::cout << "Check device_index " << std::endl;
+    for (const auto& pair : device_index) {
+      std::cout << "Key: " << pair.first << ", Value: " << pair.second
+                << std::endl;
+    }
     if (!sharding.iota_reshape_dims().empty()) {
       auto tileAssignment = xla::TileAssignment(
           sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(),
@@ -384,7 +403,10 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
     }
     for (size_t i = 0; i < tile_assignment_devices.size(); i++) {
       int64_t core = tile_assignment_devices[i];
+      std::cout << "Check core " << core << std::endl;
       if (device_index.find(core) == device_index.end()) {
+        std::cout << "current core " << core << " is not in device_index"
+                  << std::endl;
         // Skip any shards whose device is not part of the `devices` list.
         continue;
       }
@@ -434,6 +456,8 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
 std::vector<at::Tensor> ShardingUtil::ShardTensor(
     const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
     const std::vector<std::string>& devices, bool padded) {
+  std::cout << "ShardingUtil::ShardTensor check devices " << devices
+            << std::endl;
   xla::OpSharding sharding;
   bool minibatch = false;
   if (shardings != nullptr) {
@@ -442,7 +466,7 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
   }
   TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type()
              << ")... and minibatch = " << minibatch << std::endl;
-  auto device_index = build_index_map(devices);
+  // auto device_index = build_index_map(devices);
   std::vector<at::Tensor> shards(devices.size());
   if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED ||
       sharding.type() == xla::OpSharding::UNKNOWN) {
@@ -464,6 +488,8 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
                      std::back_inserter(shard_indices),
                      [](auto& pair) { return pair.second; });
     }
+    std::cout << "ShardingUtil::ShardTensor check shard_indices: "
+              << shard_indices << std::endl;
 
     for (size_t i = 0; i < shard_indices.size(); i++) {
       at::Tensor shard = tensor.index(

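Patch 02 splits the old GetNumDevices() into GetNumLocalDevices() and GetNumGlobalDevices() and exposes the global count through the new _xla_num_global_devices binding. A small sketch for checking both counts from Python (assuming an initialized runtime; the values depend on your topology):

    import torch_xla.runtime as xr

    # Backed by the new _xla_num_global_devices binding (client_->device_count()).
    print('global devices:', xr.global_runtime_device_count())
    # Devices addressable by this process, now served by GetNumLocalDevices().
    print('local devices :', xr.addressable_runtime_device_count())
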
From 2f33433c5562e58a6ba58bb08338e50dc4a10d93 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Sun, 9 Mar 2025 22:06:13 +0000
Subject: [PATCH 03/16] skip the case of no tile assignment devices when
 retrieving num partitions in the lowering context

---
 torch_xla/csrc/lowering_context.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index a004be88c54..5a6621bb49a 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -355,6 +355,9 @@ void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) {
   if (op_sharding.has_value()) {
     size_t curr_num_partitions =
         op_sharding.value().tile_assignment_devices().size();
+    if (curr_num_partitions == 0) {
+      return;
+    }
     if (num_computation_partitions_ != 1) {
       XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
           << "Number of partitions must be the same for all ops in a HLO "

From b633c76ef9a4ab945a41619b61f3e7175f28b628 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Sun, 9 Mar 2025 22:06:54 +0000
Subject: [PATCH 04/16] use env var for local spmd

---
 torch_xla/csrc/runtime/computation_client.h       | 2 +-
 torch_xla/csrc/runtime/pjrt_computation_client.cc | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 339d2a4f52c..bc01a9af33d 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -250,7 +250,7 @@ class ComputationClient {
     const xla::Shape* output_shape = nullptr;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
-    size_t computation_num_partitions;
+    size_t computation_num_partitions = 1;
     bool allow_spmd_sharding_propagation_to_output;
     bool use_auto_spmd_partitioning;
     std::vector<int64_t> auto_spmd_mesh_shape;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 6bf6217c036..a81a16c0fb7 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -566,9 +566,11 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
           .set_allow_spmd_sharding_propagation_to_output(
               {instance.allow_spmd_sharding_propagation_to_output});
 
-      int num_partitions = client_->device_count();
-      // num_partitions = 4;
-      num_partitions = static_cast<int>(instance.computation_num_partitions);
+      int num_partitions = GetNumGlobalDevices();
+      if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
+        num_partitions = GetNumLocalDevices();
+      }
+      // num_partitions = static_cast<int>(instance.computation_num_partitions);
       std::cout << "num_partitions: " << num_partitions << std::endl;
       compile_options.executable_build_options.set_num_partitions(
           num_partitions);

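Patch 04 gates local SPMD behind the XLA_USE_LOCAL_SPMD environment variable rather than the partition count derived during lowering. A hedged usage sketch (the flag is read on the C++ compile path, so it must be set before the first graph is compiled; exporting it before importing torch_xla is the safest assumption):

    import os

    # Compile SPMD programs for this host's addressable devices only.
    os.environ['XLA_USE_LOCAL_SPMD'] = '1'

    import torch_xla.runtime as xr
    xr.use_spmd()
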
From ba4b480ea94540ca4b8b849a8ef9c45446b4b5f5 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 00:36:47 +0000
Subject: [PATCH 05/16] get num partitions from prod of tile dims

---
 torch_xla/csrc/xla_sharding_util.cpp | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index b2938f81dbe..00727187e37 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -382,13 +382,20 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
         runtime::GetComputationClient()->GetNumLocalDevices();
     size_t num_global_devices =
         runtime::GetComputationClient()->GetNumGlobalDevices();
-    XLA_CHECK(tile_assignment_devices.size() == num_global_devices ||
-              tile_assignment_devices.size() == num_local_devices)
-        << "Number of tile_assignment_devices must be the number of global "
-           "devices or local devices";
+    // XLA_CHECK(tile_assignment_devices.size() == 0 ||
+    //           tile_assignment_devices.size() == num_global_devices ||
+    //           tile_assignment_devices.size() == num_local_devices)
+    //     << "Number of tile_assignment_devices must be the number of global "
+    //        "devices or local devices, or 0, got unexpected size of "
+    //     << tile_assignment_devices.size();
+    size_t num_tiles = std::accumulate(
+      sharding.tile_assignment_dimensions().begin(),
+      sharding.tile_assignment_dimensions().end(), 1, 
+        [](int a, int b) { return a * b; });
     std::cout << "Num local devices " << num_local_devices << std::endl;
+    std::cout << "Num tile assignment size " << tile_assignment_devices.size() << std::endl;
     std::unordered_map<int, int> device_index =
-        build_index_map(devices, tile_assignment_devices.size());
+        build_index_map(devices, num_tiles);
     std::cout << "Check device_index " << std::endl;
     for (const auto& pair : device_index) {
       std::cout << "Key: " << pair.first << ", Value: " << pair.second

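Patch 05 derives the number of tiles from the product of tile_assignment_dimensions, since iota-style shardings can carry an empty tile_assignment_devices list. An illustrative Python equivalent of that accumulate (hypothetical values):

    import math

    # e.g. a 2x4 tile assignment over 8 local devices; for iota/V2 shardings
    # tile_assignment_devices may be empty, but the dimensions are always set.
    tile_assignment_dimensions = [2, 4]
    num_tiles = math.prod(tile_assignment_dimensions)
    assert num_tiles == 8
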
From 7561786cb5c4acd1b1c98fe62c603bae390b8dc7 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 00:37:07 +0000
Subject: [PATCH 06/16] clang

---
 torch_xla/csrc/xla_sharding_util.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 00727187e37..5e168433a44 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -388,12 +388,13 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
     //     << "Number of tile_assignment_devices must be the number of global "
     //        "devices or local devices, or 0, got unexpected size of "
     //     << tile_assignment_devices.size();
-    size_t num_tiles = std::accumulate(
-      sharding.tile_assignment_dimensions().begin(),
-      sharding.tile_assignment_dimensions().end(), 1, 
-        [](int a, int b) { return a * b; });
+    size_t num_tiles =
+        std::accumulate(sharding.tile_assignment_dimensions().begin(),
+                        sharding.tile_assignment_dimensions().end(), 1,
+                        [](int a, int b) { return a * b; });
     std::cout << "Num local devices " << num_local_devices << std::endl;
-    std::cout << "Num tile assignment size " << tile_assignment_devices.size() << std::endl;
+    std::cout << "Num tile assignment size " << tile_assignment_devices.size()
+              << std::endl;
     std::unordered_map<int, int> device_index =
         build_index_map(devices, num_tiles);
     std::cout << "Check device_index " << std::endl;

From ff12d44a855d3a61b13aa55483ccc568694d3f36 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 00:58:26 +0000
Subject: [PATCH 07/16] add a test for shard tensor for local mesh

---
 test/cpp/test_xla_sharding.cpp | 107 +++++++++++++++++++++++++++++++++
 1 file changed, 107 insertions(+)

diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index e1f908b5c80..ed9b5b7677c 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -222,6 +222,113 @@ TEST_F(XLAShardingTest, ShardTensor) {
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
 }
 
+TEST_F(XLAShardingTest, ShardTensorLocalMesh) {
+  // Test sharding with a local mesh.
+  std::vector<std::string> devices = {"TPU:8",  "TPU:9",  "TPU:10", "TPU:11",
+                                      "TPU:12", "TPU:13", "TPU:14", "TPU:15"};
+
+  // 1D tiled
+  at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::OpSharding sharding =
+      xla::HloSharding::Tile1D(
+          CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
+          devices.size())
+          .ToProto();
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                          /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  for (auto shard : shards) {
+    EXPECT_EQ(shard.sizes(), c10::ArrayRef<long>({1}));
+  }
+
+  // 2D tiled, the first dim is halved and the last replicated. The last shard
+  // size should be smaller in dim=1 because it's not evenly divisible.
+  tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::Array2D<int64_t> mesh({
+      {0, 1, 2, 3},
+      {4, 5, 6, 7},
+  });
+  sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({4, 1, 4}));
+
+  // 3D tiled, the first dim is replicated and the last halved. The last shard
+  // size should be smaller in dim=1 because it's not evenly divisible.
+  xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
+  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));
+
+  // Replicated, all shards should be identical.
+  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 7, 4}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 7, 4}));
+
+  // 4D tiled, the first and second dims are replicated and the last halved. The
+  // last shard size should be smaller in dim=2 because it's not evenly
+  // divisible.
+  tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
+  sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 1, 2}));
+
+  // 4D tiled and padded, all shard sizes should be identical.
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
+
+  // 5D tiled, the first and second dims are replicated and the rest halved.
+  // The last shard size should be smaller in dim=3 because dim 3 is not
+  // evenly divisible.
+  tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
+  hypercube.FillIota(0);
+  sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 3, 2}));
+
+  // 5D tiled and padded, all shard sizes should be identical.
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
+}
+
 TEST_F(XLAShardingTest, ShardTensorMultiHost) {
   std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
 

From 98731291ef9820262a7e54d1fd417154843b3a56 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 01:11:27 +0000
Subject: [PATCH 08/16] use env var for device assignment handling

---
 .../csrc/runtime/pjrt_computation_client.cc   | 20 +++++++++----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index a81a16c0fb7..94160ca4949 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -601,20 +601,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       }
 
       // TODO(244391366) verify this is correct for the collectives ops
-      // xla::DeviceAssignment device_assignment(1, client_->device_count());
       xla::DeviceAssignment device_assignment(1, num_partitions);
-      std::cout << "check client_->device_count(): " << client_->device_count()
-                << std::endl;
       // DeviceAssignment values must be the PjRtDevice ID, so we need to
       // unwind the global ordinal mapping.
-      // for (const auto& [device_id, global_ordinal] : global_ordinals_) {
-      //   std::cout << "device_id: " << device_id
-      //             << ", global_ordinal: " << global_ordinal << std::endl;
-      //   device_assignment(0, global_ordinal) = device_id;
-      // }
-      auto local_pjrt_devices = client_->addressable_devices();
-      for (int i = 0; i < local_pjrt_devices.size(); ++i) {
-        device_assignment(0, i) = local_pjrt_devices[i]->id();
+      if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
+        auto local_pjrt_devices = client_->addressable_devices();
+        for (int i = 0; i < local_pjrt_devices.size(); ++i) {
+          device_assignment(0, i) = local_pjrt_devices[i]->id();
+        }
+      } else {
+        for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+          device_assignment(0, global_ordinal) = device_id;
+        }
       }
       compile_options.executable_build_options.set_device_assignment(
           device_assignment);

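Patch 08 selects the single-replica device assignment row based on the same XLA_USE_LOCAL_SPMD flag. A rough sketch of the two resulting rows (hypothetical ids for a 2-host, 8-devices-per-host setup; the real code fills an xla::DeviceAssignment in C++):

    # Global SPMD: one replica whose columns follow the global ordinal order.
    global_row = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

    # Local SPMD on host 1: columns are only the addressable PjRt device ids.
    local_row = [8, 9, 10, 11, 12, 13, 14, 15]
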
From e4183509ecf103b3b76f9a8518b5f14ab2a431f2 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 01:31:37 +0000
Subject: [PATCH 09/16] add assertion, comment for xla sharding python api

---
 torch_xla/distributed/spmd/xla_sharding.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index de2714ad249..dc82af375af 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -63,12 +63,15 @@ def __init__(self,
       device_ids = np.array(device_ids)
     assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
     assert axis_names is None or (len(set(axis_names)) == len(axis_names))
+    # size of device_ids matches mesh_shape
     assert (len(device_ids) == np.prod(mesh_shape))
+    # device ids are unique
     assert len(device_ids) == len(np.unique(device_ids))
+    # device ids are contiguous
+    assert all(d < self.size() for d in device_ids - np.min(device_ids))
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    # assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -382,6 +385,15 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
 
 
 def _normalize_logical_mesh(device_mesh: np.ndarray) -> np.ndarray:
+  """
+  Normalize the device mesh to start from 0.
+  
+  This is needed when mesh doesn't include all global devices
+  (e.g. In multi-host setup, each host has a mesh containing local devices).
+  Because HLO graph always use logical device ids in the sharding annotation,
+  we need to normalize the physical device ids to generate the correct HLO
+  sharding annotation.
+  """
   device_id_min = np.min(device_mesh)
   return device_mesh.copy() - device_id_min
 

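A small worked example (illustrative only) of the normalization and contiguity assertion that patch 09 adds:

    import numpy as np

    # Host 1 of a 2-host setup owns physical device ids 8..15.
    device_ids = np.arange(8, 16)
    mesh_shape = (2, 4)

    # _normalize_logical_mesh: shift ids so the logical mesh starts at 0, which
    # is what the HLO sharding annotation expects.
    logical_mesh = device_ids.reshape(mesh_shape) - np.min(device_ids)
    # logical_mesh == [[0, 1, 2, 3], [4, 5, 6, 7]]

    # The contiguity check from Mesh.__init__: after subtracting the minimum,
    # every id must be smaller than the mesh size.
    assert all(d < np.prod(mesh_shape) for d in device_ids - np.min(device_ids))

Note that the assertion calls self.size(), which reads self.mesh_shape, so it has to run after the attributes are assigned; that ordering is what patch 10 below fixes.
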
From e2d157b3f73517ff9a75924c50ab3f9c69d2a922 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 01:42:13 +0000
Subject: [PATCH 10/16] fix assertion

---
 torch_xla/distributed/spmd/xla_sharding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index dc82af375af..f16a9906887 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -67,11 +67,11 @@ def __init__(self,
     assert (len(device_ids) == np.prod(mesh_shape))
     # device ids are unique
     assert len(device_ids) == len(np.unique(device_ids))
-    # device ids are contiguous
-    assert all(d < self.size() for d in device_ids - np.min(device_ids))
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
+    # device ids are contiguous
+    assert all(d < self.size() for d in device_ids - np.min(device_ids))
 
   def size(self):
     return np.prod(self.mesh_shape)

From 3865f67e6847111899a324c475182bc82a7af5cf Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 02:57:42 +0000
Subject: [PATCH 11/16] remove debug prints and the attempt to derive num
 partitions from lowering

---
 torch_xla/csrc/lowering_context.cpp           | 45 -------------------
 torch_xla/csrc/lowering_context.h             | 10 -----
 torch_xla/csrc/runtime/computation_client.h   |  3 --
 .../csrc/runtime/pjrt_computation_client.cc   | 24 +---------
 torch_xla/csrc/xla_graph_executor.cpp         | 14 +-----
 torch_xla/csrc/xla_sharding_util.cpp          | 34 ++------------
 torch_xla/distributed/spmd/xla_sharding.py    |  4 --
 7 files changed, 6 insertions(+), 128 deletions(-)

diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index 5a6621bb49a..6c2906dc724 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -93,7 +93,6 @@ LoweringContext::LoweringContext(const std::string& name,
                                  torch::lazy::BackendDevice device)
     : torch::lazy::LoweringContext(name, device),
       builder_(name),
-      num_computation_partitions_(1),
       stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {}
 
 LoweringContext::LoweringContext(
@@ -102,7 +101,6 @@ LoweringContext::LoweringContext(
     torch::lazy::Util::EmissionMap emit_status)
     : torch::lazy::LoweringContext(name, device, {}, emit_status),
       builder_(name),
-      num_computation_partitions_(1),
       stack_frame_index_builder_(std::make_shared<StackFrameIndexBuilder>()) {
   for (auto node : post_order) {
     LowerNode(node);
@@ -133,7 +131,6 @@ xla::XlaOp LoweringContext::GetParameter(
       xla::OpSharding sharding = data->GetSharding();
       xla::XlaScopedShardingAssignment scoped_sharding(builder(), sharding);
       param = xla::Parameter(builder(), param_index, shape, param_name);
-      UpdateNumPartitions(param);
     } else {
       param = xla::Parameter(builder(), param_index, shape, param_name);
     }
@@ -257,28 +254,6 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {
         mutable_dims->Set(dim, kUnboundedSize);
       }
     }
-    std::for_each(result_ops.begin(), result_ops.end(),
-                  [this](xla::XlaOp xla_op) {
-                    UpdateNumPartitions(xla_op);  // Calling the member function
-                  });
-    // for (auto xla_op : result_ops) {
-    //   UpdateNumPartitions(xla_op);
-    //   // std::optional<OpSharding> op_sharding =
-    //   //   ConsumeValue(builder()->GetOpSharding(xla_op));
-    //   // if (op_sharding.has_value()) {
-    //   //   size_t curr_num_partitions =
-    //   //     op_sharding.value().tile_assignment_devices().size();
-    //   //   if (num_computation_partitions_ != 1) {
-    //   //     XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
-    //   <<
-    //   //       "Number of partitions must be the same for all ops in a HLO
-    //   graph.";
-    //   //     continue;
-    //   //   }
-    //   //   num_computation_partitions_ =
-    //   op_sharding.value().tile_assignment_devices().size();
-    //   // }
-    // }
   } catch (const std::exception& ex) {
     ReportBuilderError(node, ex.what());
   }
@@ -349,24 +324,4 @@ torch::lazy::ComputationPtr LoweringContext::Build() {
       builder_.name(), std::move(xla_computation), device_);
 }
 
-void LoweringContext::UpdateNumPartitions(const xla::XlaOp& op) {
-  std::optional<xla::OpSharding> op_sharding =
-      ConsumeValue(builder()->GetOpSharding(op));
-  if (op_sharding.has_value()) {
-    size_t curr_num_partitions =
-        op_sharding.value().tile_assignment_devices().size();
-    if (curr_num_partitions == 0) {
-      return;
-    }
-    if (num_computation_partitions_ != 1) {
-      XLA_CHECK_EQ(curr_num_partitions, num_computation_partitions_)
-          << "Number of partitions must be the same for all ops in a HLO "
-             "graph.";
-      return;
-    }
-    std::cout << "curr_num_partitions: " << curr_num_partitions << std::endl;
-    num_computation_partitions_ = curr_num_partitions;
-  }
-}
-
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h
index fdaabb2b14d..cb4f0bc2d2f 100644
--- a/torch_xla/csrc/lowering_context.h
+++ b/torch_xla/csrc/lowering_context.h
@@ -113,18 +113,10 @@ class LoweringContext : public torch::lazy::LoweringContext {
     return emitted_outputs_;
   }
 
-  size_t GetComputationNumPartitions() const {
-    return num_computation_partitions_;
-  }
-
   // Return stack frame id
   int64_t AddStackFrameLocation(const torch::lazy::SourceLocation& source,
                                 int64_t parent_id);
 
- protected:
-  // Update the number of partitions from a XlaOp.
-  void UpdateNumPartitions(const xla::XlaOp& op);
-
  private:
   struct Parameter {
     xla::XlaOp param;
@@ -141,8 +133,6 @@ class LoweringContext : public torch::lazy::LoweringContext {
   std::vector<xla::XlaOp> root_tuple_;
   OutputMap<xla::XlaOp> emitted_outputs_;
   std::string name_;
-  // Number of partitions of the lowered XLA computation.
-  size_t num_computation_partitions_;
 
   std::shared_ptr<StackFrameIndexBuilder> stack_frame_index_builder_;
 };  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index bc01a9af33d..20915de32e2 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -225,7 +225,6 @@ class ComputationClient {
         xla::XlaComputation computation, std::string compilation_device,
         std::vector<std::string> devices, const xla::Shape* output_shape,
         bool parameter_is_tupled_arguments = false, bool is_sharded = false,
-        size_t computation_num_partitions = 1,
         bool allow_spmd_sharding_propagation_to_output = true,
         bool use_auto_spmd_partitioning = false,
         std::vector<int64_t> auto_spmd_mesh_shape = {},
@@ -236,7 +235,6 @@ class ComputationClient {
           output_shape(output_shape),
           parameter_is_tupled_arguments(parameter_is_tupled_arguments),
           is_sharded(is_sharded),
-          computation_num_partitions(computation_num_partitions),
           allow_spmd_sharding_propagation_to_output(
               allow_spmd_sharding_propagation_to_output),
           use_auto_spmd_partitioning(use_auto_spmd_partitioning),
@@ -250,7 +248,6 @@ class ComputationClient {
     const xla::Shape* output_shape = nullptr;
     bool parameter_is_tupled_arguments;
     bool is_sharded;
-    size_t computation_num_partitions = 1;
     bool allow_spmd_sharding_propagation_to_output;
     bool use_auto_spmd_partitioning;
     std::vector<int64_t> auto_spmd_mesh_shape;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 94160ca4949..3783bb61b5d 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -334,7 +334,6 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
 std::shared_ptr<PjRtComputationClient::PjRtData>
 PjRtComputationClient::ReplicateShardedData(
     const ComputationClient::DataPtr& handle) {
-  std::cout << "PjRtComputationClient::ReplicateShardedData" << std::endl;
   if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
     return unsharded_data;
   } else if (auto sharded_data =
@@ -348,9 +347,7 @@ PjRtComputationClient::ReplicateShardedData(
     }
     xla::XlaBuilder builder("ReplicateShardedData");
     xla::Shape shape = sharded_data->shape();
-    xla::OpSharding sharding = sharded_data->GetSharding();
-    builder.SetSharding(sharding);
-    size_t num_partitions = sharding.tile_assignment_devices().size();
+    builder.SetSharding(sharded_data->GetSharding());
 
     // perform a simple identity calculation to reassemble the input as
     // replicated output.
@@ -374,7 +371,6 @@ PjRtComputationClient::ReplicateShardedData(
                          GetCompilationDevices(device, {}), &shape,
                          /*should_wrap_parameter=*/false,
                          /*is_sharded=*/true,
-                         /*computation_num_partitions*/ num_partitions,
                          /*allow_spmd_sharding_propagation_to_output=*/false});
     std::vector<
         std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -541,7 +537,6 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
 
 std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     std::vector<ComputationClient::CompileInstance> instances) {
-  std::cout << "in compile" << std::endl;
   auto metrics_fn = CompileMetric;
   if (instances[0].eager_mode) {
     metrics_fn = EagerCompileMetric;
@@ -551,9 +546,7 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
                                   tsl::profiler::TraceMeLevel::kInfo);
   std::vector<ComputationClient::ComputationPtr> computations;
 
-  std::cout << "instances.size(): " << instances.size() << std::endl;
   for (auto& instance : instances) {
-    std::cout << "instance devices " << instance.devices << std::endl;
     xla::CompileOptions compile_options;
     if (instance.is_sharded) {
       // TODO(yeounoh) multi-host, multi-slice configurations
@@ -570,8 +563,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
         num_partitions = GetNumLocalDevices();
       }
-      // num_partitions = static_cast<int>(instance.computation_num_partitions);
-      std::cout << "num_partitions: " << num_partitions << std::endl;
       compile_options.executable_build_options.set_num_partitions(
           num_partitions);
       compile_options.executable_build_options.set_num_replicas(1);
@@ -668,7 +659,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 
     CreateCompileHandlesCounter()->AddValue(1);
   }
-  std::cout << "finish compile" << std::endl;
   return computations;
 }
 
@@ -720,7 +710,6 @@ PjRtComputationClient::ExecuteComputation(
     const ComputationClient::Computation& computation,
     absl::Span<const ComputationClient::DataPtr> arguments,
     const std::string& device, const ExecuteComputationOptions& options) {
-  std::cout << "in execute" << std::endl;
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteComputation` and the async work in `ExecuteSharded` are
   // complete; a copy is held from the lambda that releases it when done.
@@ -788,7 +777,6 @@ PjRtComputationClient::ExecuteComputation(
   CreateDataHandlesCounter()->AddValue(datas.size());
 
   TF_VLOG(1) << "Returning " << datas.size() << " results";
-  std::cout << "finish execute" << std::endl;
   return datas;
 }
 
@@ -798,10 +786,6 @@ PjRtComputationClient::ExecuteReplicated(
     absl::Span<const ComputationClient::DataPtr> arguments,
     absl::Span<const std::string> devices,
     const ExecuteReplicatedOptions& options) {
-  std::cout << "in execute replicated" << std::endl;
-  for (auto d : devices) {
-    std::cout << "device: " << d << std::endl;
-  }
   // Shared ownership of the timed section ensures that it will only get logged
   // once both `ExecuteReplicated` and the async work in `Execute` are
   // complete; a copy is held from the lambda that releases it when done.
@@ -939,7 +923,6 @@ PjRtComputationClient::ExecuteReplicated(
   }
 
   TF_VLOG(1) << "Returning " << data_handles.size() << " sharded outputs.";
-  std::cout << "finish execute replicated" << std::endl;
   return data_handles;
 }
 
@@ -1002,17 +985,12 @@ xla::PjRtDevice* PjRtComputationClient::StringToPjRtDevice(
 
 void PjRtComputationClient::WaitDeviceOps(
     absl::Span<const std::string> devices) {
-  std::cout << "in wait device ops" << std::endl;
-  for (auto d : devices) {
-    std::cout << "device: " << d << std::endl;
-  }
   TF_VLOG(3) << "Waiting for " << absl::StrJoin(devices, ", ");
   operation_manager_.WaitForDevices(
       devices.empty()
           ? (UseVirtualDevice() ? std::vector<std::string>({spmd_device_str})
                                 : GetLocalDevices())
           : devices);
-  std::cout << "finish wait device ops" << std::endl;
 }
 
 std::map<std::string, Metric> PjRtComputationClient::GetMetrics() const {
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 514266518dc..c33b5431455 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1391,16 +1391,12 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   // Always execute sharded when running in SPMD mode
   bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
   // Annotate HLO sharding selectively in the compuation.
-  bool is_sharded_2 = ShardingUtil::SetHloSharding(&lowering_ctx);
-
-  std::cout << "is_sharded_2: " << is_sharded_2 << std::endl;
+  ShardingUtil::SetHloSharding(&lowering_ctx);
 
   SetBufferDonors(&lowering_ctx, buffer_donor_indices);
 
   xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
   xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
-  size_t computation_num_partitions =
-      lowering_ctx.GetComputationNumPartitions();
 
   // TODO(yeounoh) enable wrapping with auto-sharding.
   bool should_wrap_parameter =
@@ -1426,15 +1422,11 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
 
   std::vector<runtime::ComputationClient::CompileInstance> instances;
-  std::cout << "computation_num_partitions: " << computation_num_partitions
-            << std::endl;
   instances.emplace_back(std::move(computation), coll.device.toString(),
                          runtime::GetComputationClient()->GetCompilationDevices(
                              coll.device.toString(), devices),
-                         &shape, should_wrap_parameter, is_sharded,
-                         computation_num_partitions);
+                         &shape, should_wrap_parameter, is_sharded);
   instances.front().eager_mode = UseEagerMode();
-  instances.front().computation_num_partitions = computation_num_partitions;
   if (use_autosharding) {
     TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
     TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode.";
@@ -1463,8 +1455,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
   TF_VLOG(3) << "Compiling IR graph hash "
              << torch::lazy::HashToString(coll.hash) << " on device "
              << coll.device << " ...";
-  std::cout << "check instance num partitions"
-            << instances.front().computation_num_partitions << std::endl;
   std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
       computations =
           runtime::GetComputationClient()->Compile(std::move(instances));
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 5e168433a44..c4399c22be1 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -192,9 +192,6 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
         XlaBuilderFriend::GetInstruction(elem.second);
     const std::shared_ptr<xla::OpSharding> sharding =
         xla_node->GetSharding(elem.first.index);
-    if (sharding != nullptr) {
-      std::cout << "check opsharding " << sharding->DebugString() << std::endl;
-    }
     if (sharding != nullptr && sharding->type() != xla::OpSharding::UNKNOWN) {
       *instruction->mutable_sharding() = *sharding;
       is_sharded = true;
@@ -375,33 +372,15 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
       shard_indices[i] = std::make_pair(global_ordinal, indices);
     }
   } else if (sharding.type() == xla::OpSharding::OTHER) {
-    std::vector<int64_t> tile_assignment_devices(
-        sharding.tile_assignment_devices().begin(),
-        sharding.tile_assignment_devices().end());
-    size_t num_local_devices =
-        runtime::GetComputationClient()->GetNumLocalDevices();
-    size_t num_global_devices =
-        runtime::GetComputationClient()->GetNumGlobalDevices();
-    // XLA_CHECK(tile_assignment_devices.size() == 0 ||
-    //           tile_assignment_devices.size() == num_global_devices ||
-    //           tile_assignment_devices.size() == num_local_devices)
-    //     << "Number of tile_assignment_devices must be the number of global "
-    //        "devices or local devices, or 0, got unexpected size of "
-    //     << tile_assignment_devices.size();
     size_t num_tiles =
         std::accumulate(sharding.tile_assignment_dimensions().begin(),
                         sharding.tile_assignment_dimensions().end(), 1,
                         [](int a, int b) { return a * b; });
-    std::cout << "Num local devices " << num_local_devices << std::endl;
-    std::cout << "Num tile assignment size " << tile_assignment_devices.size()
-              << std::endl;
     std::unordered_map<int, int> device_index =
         build_index_map(devices, num_tiles);
-    std::cout << "Check device_index " << std::endl;
-    for (const auto& pair : device_index) {
-      std::cout << "Key: " << pair.first << ", Value: " << pair.second
-                << std::endl;
-    }
+    std::vector<int64_t> tile_assignment_devices(
+        sharding.tile_assignment_devices().begin(),
+        sharding.tile_assignment_devices().end());
     if (!sharding.iota_reshape_dims().empty()) {
       auto tileAssignment = xla::TileAssignment(
           sharding.tile_assignment_dimensions(), sharding.iota_reshape_dims(),
@@ -411,10 +390,7 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
     }
     for (size_t i = 0; i < tile_assignment_devices.size(); i++) {
       int64_t core = tile_assignment_devices[i];
-      std::cout << "Check core " << core << std::endl;
       if (device_index.find(core) == device_index.end()) {
-        std::cout << "current core " << core << " is not in device_index"
-                  << std::endl;
         // Skip any shards whose device is not part of the `devices` list.
         continue;
       }
@@ -464,8 +440,6 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
 std::vector<at::Tensor> ShardingUtil::ShardTensor(
     const at::Tensor& tensor, const XLATensor::ShardingSpecPtr shardings,
     const std::vector<std::string>& devices, bool padded) {
-  std::cout << "ShardingUtil::ShardTensor check devices " << devices
-            << std::endl;
   xla::OpSharding sharding;
   bool minibatch = false;
   if (shardings != nullptr) {
@@ -496,8 +470,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
                      std::back_inserter(shard_indices),
                      [](auto& pair) { return pair.second; });
     }
-    std::cout << "ShardingUtil::ShardTensor check shard_indices: "
-              << shard_indices << std::endl;
 
     for (size_t i = 0; i < shard_indices.size(); i++) {
       at::Tensor shard = tensor.index(
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index f16a9906887..4bc0b71318f 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -130,10 +130,6 @@ def get_op_sharding(self,
 
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
-    print(f"check tile_assignment: {tile_assignment}")
-    print(f"check group_assignment: {group_assignment}")
-    print(f"check replication_groups: {replication_groups}")
-    print(f"check sharding_type: {sharding_type}")
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 

From d3feb5f7097f082897086f0f7add7e06b13e488a Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Mon, 10 Mar 2025 06:00:21 +0000
Subject: [PATCH 12/16] remove unused var

---
 torch_xla/csrc/xla_sharding_util.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index c4399c22be1..4289cf0e00c 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -448,7 +448,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
   }
   TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type()
              << ")... and minibatch = " << minibatch << std::endl;
-  // auto device_index = build_index_map(devices);
   std::vector<at::Tensor> shards(devices.size());
   if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED ||
       sharding.type() == xla::OpSharding::UNKNOWN) {

From 2f1fc1a411cb7bb4abc9181aeb4197002278ca34 Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Thu, 13 Mar 2025 06:13:21 +0000
Subject: [PATCH 13/16] add comment for the modulo op in the device index map
 util func

---
 torch_xla/csrc/xla_sharding_util.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 4289cf0e00c..58bc0ac2053 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -88,6 +88,11 @@ std::unordered_map<int, int> build_index_map(
     const std::vector<std::string>& devices, size_t num_mesh_devices) {
   std::unordered_map<int, int> device_index;
   for (int i = 0; i < devices.size(); ++i) {
+    // The global ordinal here is the device's ordinal in the mesh, which
+    // can be different from the physical device index.
+    // We only support 2 cases here:
+    // 1. Mesh contains all global devices.
+    // 2. Mesh contains only local devices. (in multi-host scenario)
     int global_ordinal =
         ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;

From a1eaaebfd4c472cbc61ec4fc3937eb4f92aea55c Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Thu, 13 Mar 2025 06:17:15 +0000
Subject: [PATCH 14/16] update comment

---
 torch_xla/csrc/xla_sharding_util.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 58bc0ac2053..2058e7490a6 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -93,6 +93,10 @@ std::unordered_map<int, int> build_index_map(
     // We only support 2 cases here:
     // 1. Mesh contains all global devices.
     // 2. Mesh contains only local devices. (in multi-host scenario)
+    //    Example: In multi-host v6e-8, each host has a mesh of its local
+    //             devices; host 1 has devices TPU:{4, 5, 6, 7}. In this case
+    //             the global ordinal of TPU:4 is 0, TPU:5 is 1, and so on.
+
     int global_ordinal =
         ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;
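
Illustration of the ordinal mapping described in the comment above, as a
minimal Python sketch. The device names and counts are hypothetical (a host
that owns TPU:4..TPU:7, with a mesh spanning only those four local devices):

    # Minimal sketch of the modulo-based ordinal mapping described above.
    # Hypothetical setup: this host owns TPU:4..TPU:7 and the mesh spans
    # only those 4 local devices (num_mesh_devices == 4).
    devices = ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]
    num_mesh_devices = 4

    device_index = {}
    for i, dev in enumerate(devices):
        physical_ordinal = int(dev.split(":")[1])
        # Global ordinal within the mesh: TPU:4 -> 0, TPU:5 -> 1, ...
        global_ordinal = physical_ordinal % num_mesh_devices
        device_index[global_ordinal] = i

    print(device_index)  # {0: 0, 1: 1, 2: 2, 3: 3}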

From e97401d4a24f203acb64147e473f084f726522ab Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Thu, 13 Mar 2025 06:27:08 +0000
Subject: [PATCH 15/16] assert on local devices in mesh constructor

---
 torch_xla/distributed/spmd/xla_sharding.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 4bc0b71318f..90a6df67388 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -71,7 +71,15 @@ def __init__(self,
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
     # device ids are continous
-    assert all(d < self.size() for d in device_ids - np.min(device_ids))
+    if min(device_ids) != 0:
+      # Mesh doesn't contain all global devices. Only creating a mesh with local
+      # devices is supported.
+      min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
+      )
+      assert min_device_idx == min(
+          device_ids
+      ), "If not creating a mesh with all global devices, must use local devices."
+    assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)
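
A minimal standalone Python sketch of the local-mesh device-id check above,
under assumed values: two hosts with four addressable devices each, running on
host 1. Plain integers stand in for xr.process_index() and
xr.addressable_runtime_device_count(), and the ids are normalized by their
minimum before the bounds check (as the follow-up patch does):

    # Standalone sketch of the local-mesh assertion, with assumed values:
    # two hosts, four addressable devices per host, running on host 1.
    import numpy as np

    process_index = 1                    # stand-in for xr.process_index()
    addressable_device_count = 4         # stand-in for xr.addressable_runtime_device_count()
    device_ids = np.array([4, 5, 6, 7])  # mesh built from local devices only
    mesh_size = 4                        # np.prod(mesh_shape), e.g. a (2, 2) mesh

    if device_ids.min() != 0:
        # Mesh doesn't contain all global devices; it must then start at this
        # host's first addressable device.
        min_device_idx = process_index * addressable_device_count
        assert min_device_idx == device_ids.min(), \
            "If not creating a mesh with all global devices, must use local devices."
    assert all(d < mesh_size for d in device_ids - device_ids.min())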

From 4bfd7b59a0512e8707d459e04de1ae9862550e2d Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Thu, 13 Mar 2025 06:36:55 +0000
Subject: [PATCH 16/16] check local devices in mesh ctor

---
 torch_xla/distributed/spmd/xla_sharding.py | 40 ++++++++++++----------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 90a6df67388..d1a1db4b644 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -1,25 +1,25 @@
 import collections
-from collections.abc import Generator, MutableMapping
+import functools
+import itertools
 import math
+import os
 from collections import OrderedDict, defaultdict
+from collections.abc import Generator, MutableMapping
 from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+
+import numpy as np
 import torch
-from torch import Tensor
-from torch.library import custom_op
 import torch_xla
-import torch_xla.core.xla_model as xm
 import torch_xla._internal.utils as _utils
-from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
-import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
-
-import numpy as np
-import functools
-import itertools
-from typing import Tuple, Union, List, Sequence, Any, Optional, Set
-from enum import IntEnum
-
-from torch.amp import custom_fwd, custom_bwd
+import torch_xla.runtime as xr
+from torch import Tensor
+from torch.amp import custom_bwd, custom_fwd
+from torch.library import custom_op
+from torch_xla.distributed.spmd import XLAShard, XLAShardedTensor
 
 
 class Mesh:
@@ -71,15 +71,16 @@ def __init__(self,
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
     # device ids are continous
-    if min(device_ids) != 0:
-      # Mesh doesn't contain all global devices. Only creating a mesh with local
-      # devices is supported.
+    if os.environ.get('XLA_USE_LOCAL_SPMD', '0') == '1':
+      # In local SPMD, the mesh only contains local devices.
       min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
       )
-      assert min_device_idx == min(
+      assert min_device_idx == np.min(
           device_ids
       ), "If not creating a mesh with all global devices, must use local devices."
-    assert all(d < self.size() for d in device_ids)
+      assert all(d < self.size() for d in device_ids - np.min(device_ids))
+    else:
+      assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -151,6 +152,7 @@ def __str__(self):
   def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
     """Create Mesh from string representation."""
     import ast
+
     import numpy as np
     try:
       dict_str = mesh_str.replace('Mesh', '')