Skip to content

Commit b9ebfbb

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Update KJT stride calculation logic to be based off of inverse_indices for VBE KJTs.
Summary: Update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` for VBE KJTs. Currently, stride of VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off of inverse_indices. This causes issues in IR module serialization: debug doc. Differential Revision: D74273083
1 parent 3e2737e commit b9ebfbb

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

torchrec/sparse/jagged_tensor.py

+4
Original file line numberDiff line numberDiff line change
@@ -1097,10 +1097,13 @@ def _maybe_compute_stride_kjt(
10971097
lengths: Optional[torch.Tensor],
10981098
offsets: Optional[torch.Tensor],
10991099
stride_per_key_per_rank: Optional[List[List[int]]],
1100+
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
11001101
) -> int:
11011102
if stride is None:
11021103
if len(keys) == 0:
11031104
stride = 0
1105+
elif inverse_indices is not None and inverse_indices[1].numel() > 0:
1106+
return inverse_indices[1].shape[1]
11041107
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
11051108
stride = max([sum(s) for s in stride_per_key_per_rank])
11061109
elif offsets is not None and offsets.numel() > 0:
@@ -2165,6 +2168,7 @@ def stride(self) -> int:
21652168
self._lengths,
21662169
self._offsets,
21672170
self._stride_per_key_per_rank,
2171+
self._inverse_indices,
21682172
)
21692173
self._stride = stride
21702174
return stride

torchrec/sparse/tests/test_keyed_jagged_tensor.py

+12
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,18 @@ def test_meta_device_compatibility(self) -> None:
10171017
lengths=torch.tensor([], device=torch.device("meta")),
10181018
)
10191019

1020+
def test_vbe_kjt_stride(self) -> None:
1021+
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
1022+
kjt = KeyedJaggedTensor(
1023+
keys=["f1", "f2", "f3"],
1024+
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
1025+
lengths=torch.tensor([3, 3, 2]),
1026+
stride_per_key_per_rank=[[2], [1]],
1027+
inverse_indices=(["f1", "f2"], inverse_indices),
1028+
)
1029+
1030+
self.assertEqual(kjt.stride(), inverse_indices.shape[1])
1031+
10201032

10211033
class TestKeyedJaggedTensorScripting(unittest.TestCase):
10221034
def test_scriptable_forward(self) -> None:

0 commit comments

Comments
 (0)