@@ -607,8 +607,25 @@ void custom_sharding_(
607
607
const XLATensorPtr& input,
608
608
const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec,
609
609
const CustomSharding::Type& type) {
610
- input->SetInPlaceIrValue (torch_xla::MakeNode<CustomSharding>(
611
- input->GetIrValue (), input->shape ().get (), type));
610
+ torch::lazy::NodePtr customShardingNode = torch_xla::MakeNode<CustomSharding>(
611
+ input->GetIrValue (), input->shape ().get (), type);
612
+ XlaNode* xla_node = dynamic_cast <XlaNode*>(customShardingNode.get ());
613
+ // Always call `SetSharding` to ensure the `CustomSharding` op has the correct
614
+ // sharding, especially if a view is updated afterward. Updating a view can
615
+ // modify the IR, potentially leading to the sharding being applied to the
616
+ // updated view instead of the original `CustomSharding` op.
617
+
618
+ // For example, consider the following IR:
619
+ // ```
620
+ // x0 = custom_sharding(input)
621
+ // x1 = view_update(x0)
622
+ // ```
623
+ // In this case, we want to ensure the sharding is applied to `x0`, not `x1`.
624
+
625
+ // While this solution may add a sharding spec to non-CustomSharding ops like
626
+ // `x1`, the XLA compiler will safely ignore it.
627
+ xla_node->SetSharding (sharding_spec->sharding , 0 );
628
+ input->SetInPlaceIrValue (customShardingNode);
612
629
input->SetShardingSpec (*sharding_spec);
613
630
}
614
631
0 commit comments