Skip to content

Commit 14acab7

Browse files
authored
Fix mark_sharding logic (#8578)
1 parent 79de480 commit 14acab7

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

test/tpu/run_tests.sh

+1
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ python3 "$TEST_CDIR/scan/test_scan_layers.py"
3636
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
3737
python3 "$TEST_CDIR/test_pallas.py" -v
3838
python3 "$TEST_CDIR/test_pallas_spmd.py"
39+
XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
3940
python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
4041
python3 "$TEST_CDIR/test_input_output_aliases.py"
4142
python3 "$TEST_CDIR/test_gmm.py"

torch_xla/csrc/tensor_methods.cpp

+19-2
Original file line number | Diff line number | Diff line change
@@ -607,8 +607,25 @@ void custom_sharding_(
607607
const XLATensorPtr& input,
608608
const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec,
609609
const CustomSharding::Type& type) {
610-
input->SetInPlaceIrValue(torch_xla::MakeNode<CustomSharding>(
611-
input->GetIrValue(), input->shape().get(), type));
610+
torch::lazy::NodePtr customShardingNode = torch_xla::MakeNode<CustomSharding>(
611+
input->GetIrValue(), input->shape().get(), type);
612+
XlaNode* xla_node = dynamic_cast<XlaNode*>(customShardingNode.get());
613+
// Always call `SetSharding` to ensure the `CustomSharding` op has the correct
614+
// sharding, especially if a view is updated afterward. Updating a view can
615+
// modify the IR, potentially leading to the sharding being applied to the
616+
// updated view instead of the original `CustomSharding` op.
617+
618+
// For example, consider the following IR:
619+
// ```
620+
// x0 = custom_sharding(input)
621+
// x1 = view_update(x0)
622+
// ```
623+
// In this case, we want to ensure the sharding is applied to `x0`, not `x1`.
624+
625+
// While this solution may add a sharding spec to non-CustomSharding ops like
626+
// `x1`, the XLA compiler will safely ignore it.
627+
xla_node->SetSharding(sharding_spec->sharding, 0);
628+
input->SetInPlaceIrValue(customShardingNode);
612629
input->SetShardingSpec(*sharding_spec);
613630
}
614631

0 commit comments

Comments
 (0)