
Commit e38a01e

jd7-tr authored and facebook-github-bot committed
Include stride_per_key_per_rank in KJT's PyTree flatten/unflatten logic (#2903)
Include stride_per_key_per_rank in KJT's PyTree flatten/unflatten logic (#2903)

Summary:

# Context

* Currently the torchrec IR serializer can't handle the variable-batch use case.
* `torch.export` captures only the keys, values, lengths, weights, and offsets of a KJT; variable-batch-related parameters such as `stride_per_key_per_rank` or `inverse_indices` are ignored.
* The test case for the VB-KJT scenario (previously skipped as an expected failure) now verifies that the `serialize_deserialize_ebc` use case works with KJTs that have a variable batch size.

Differential Revision: D73051959
1 parent 33aeafa commit e38a01e
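For readers unfamiliar with the mechanism: `torch.export` tears a module's inputs down to pytree leaves (tensors) plus a static "context", and anything captured in neither place is silently dropped on the round trip. The sketch below is a minimal, hypothetical illustration of that contract — the `Bag` class is invented for this example and is not torchrec code:

```python
import torch
import torch.utils._pytree as pytree


class Bag:
    """Toy stand-in for a KJT: one tensor leaf plus non-tensor metadata."""

    def __init__(self, keys, values, stride_per_key_per_rank=None):
        self.keys = keys
        self.values = values
        self.stride_per_key_per_rank = stride_per_key_per_rank


def _flatten(b):
    # Tensors become leaves; everything else must ride in the context,
    # or it will not survive tree_unflatten.
    return [b.values], (b.keys, b.stride_per_key_per_rank)


def _unflatten(values, context):
    return Bag(context[0], values[0], stride_per_key_per_rank=context[1])


pytree.register_pytree_node(Bag, _flatten, _unflatten)

b = Bag(["f1"], torch.arange(4), stride_per_key_per_rank=[[2]])
leaves, spec = pytree.tree_flatten(b)
restored = pytree.tree_unflatten(leaves, spec)
assert restored.stride_per_key_per_rank == [[2]]  # metadata round-trips
```

This commit applies the same idea to `KeyedJaggedTensor`: `_stride_per_key_per_rank` moves into the flatten context, and the inverse-indices tensor becomes an extra leaf.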

File tree

4 files changed: +64 -32 lines

  torchrec/ir/tests/test_serializer.py
  torchrec/modules/embedding_modules.py
  torchrec/sparse/jagged_tensor.py
  torchrec/sparse/tests/test_keyed_jagged_tensor.py

torchrec/ir/tests/test_serializer.py

Lines changed: 5 additions & 24 deletions
@@ -292,7 +292,6 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
         id_list_features = KeyedJaggedTensor(
@@ -319,15 +318,16 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         # Run forward on ExportedProgram
         ep_output = ep.module()(id_list_features)

+        self.assertEqual(len(ep_output), len(id_list_features.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -342,29 +342,9 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)

-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
+
         deserialized_out = deserialized_model(id_list_features)

         self.assertEqual(len(deserialized_out), len(eager_out))
@@ -385,6 +365,7 @@ def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
             values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
             offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
         )
+
         eager_out = model(feature2)

         # Serialize EBC
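A note on the weakened assertion above: under variable batch sizes, the row dimension of the eager output and the exported-program output need not agree (the eager path regroups rows via `inverse_indices`, while the exported graph may not), so the test now compares only the embedding dimension. A hypothetical illustration of the shapes involved — the numbers are invented:

```python
eager_shape = (3, 8)     # rows after inverse-indices regrouping (assumed)
exported_shape = (2, 8)  # rows as traced through torch.export (assumed)

# dim 0 may legitimately differ under variable batch sizes;
# dim 1 (embedding_dim) must still match.
assert eager_shape[1] == exported_shape[1]
```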

torchrec/modules/embedding_modules.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def reorder_inverse_indices(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
     feature_names: List[str],
 ) -> torch.Tensor:
-    if inverse_indices is None:
+    if inverse_indices is None or inverse_indices[1].numel() == 0:
         return torch.empty(0)
     index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
     index = torch.tensor(
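The extra `numel() == 0` guard is needed because, after this change, a KJT reconstructed via pytree unflatten carries an empty placeholder tensor (rather than `None`) when no inverse indices were set, and the empty tensor must be treated like the `None` case. A minimal sketch of the condition being guarded — the values are invented:

```python
import torch

# Placeholder produced by unflattening a non-VBE KJT (see __init__ diff below).
inverse_indices = (["f1"], torch.empty(0))

if inverse_indices is None or inverse_indices[1].numel() == 0:
    index = torch.empty(0)  # same fallback as the None branch
```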

torchrec/sparse/jagged_tensor.py

Lines changed: 41 additions & 7 deletions
@@ -8,9 +8,11 @@
 # pyre-strict

 import abc
+import dataclasses
 import logging

 import operator
+from dataclasses import dataclass

 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -1756,6 +1758,7 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_inverse_indices_tensor",
     ]

     def __init__(
@@ -1800,6 +1803,9 @@ def __init__(
         self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
             inverse_indices
         )
+        self._inverse_indices_tensor: Optional[torch.Tensor] = torch.empty(0)
+        if inverse_indices is not None:
+            self._inverse_indices_tensor = inverse_indices[1]

         # legacy attribute, for backward compatabilibity
         self._variable_stride_per_key: Optional[bool] = None
@@ -3030,15 +3036,32 @@ def dist_init(
     return kjt.sync()


+@dataclass
+class KjtTreeSpecs:
+    keys: List[str]
+    stride_per_key_per_rank: Optional[List[List[int]]]
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            field.name: getattr(self, field.name) for field in dataclasses.fields(self)
+        }
+
+
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], Optional[List[List[int]]]]]:
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (
+        t._keys,
+        t._stride_per_key_per_rank,
+    )


 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[
+    List[Tuple[KeyEntry, Optional[torch.Tensor]]],
+    Tuple[List[str], Optional[List[List[int]]]],
+]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3047,9 +3070,17 @@ def _kjt_flatten_with_keys(


 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[
+        List[str], Optional[List[List[int]]]
+    ],  # context is the (_keys, _stride_per_key_per_rank) tuple
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context[0],
+        *values[:-1],
+        stride_per_key_per_rank=context[1],
+        inverse_indices=(context[0], values[-1]),
+    )


 def _kjt_flatten_spec(
@@ -3070,7 +3101,9 @@ def _kjt_flatten_spec(

 def flatten_kjt_list(
     kjt_arr: List[KeyedJaggedTensor],
-) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]:
+) -> Tuple[
+    List[Optional[torch.Tensor]], List[Tuple[List[str], Optional[List[List[int]]]]]
+]:
     _flattened_data = []
     _flattened_context = []
     for t in kjt_arr:
@@ -3081,7 +3114,8 @@ def flatten_kjt_list(


 def unflatten_kjt_list(
-    values: List[Optional[torch.Tensor]], contexts: List[List[str]]
+    values: List[Optional[torch.Tensor]],
+    contexts: List[Tuple[List[str], Optional[List[List[int]]]]],
 ) -> List[KeyedJaggedTensor]:
     num_kjt_fields = len(KeyedJaggedTensor._fields)
     length = len(values)
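Taken together, the new convention is: the leaves are `(_values, _weights, _lengths, _offsets, _inverse_indices_tensor)` and the context is `(_keys, _stride_per_key_per_rank)`. A short round-trip sketch of how the pieces reassemble, assuming a torchrec install that includes this change (`_kjt_flatten`/`_kjt_unflatten` are the module-private helpers shown above):

```python
import torch
from torchrec.sparse.jagged_tensor import (
    KeyedJaggedTensor,
    _kjt_flatten,
    _kjt_unflatten,
)

# Same VBE KJT as the new unit test below.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

leaves, context = _kjt_flatten(kjt)
keys, stride_per_key_per_rank = context  # non-tensor metadata rides in the context

# Unflatten splits the leaves back apart: *leaves[:-1] are the regular tensor
# fields, and leaves[-1] is reattached as inverse_indices=(keys, tensor).
restored = _kjt_unflatten(leaves, context)
assert restored.stride_per_key_per_rank() == kjt.stride_per_key_per_rank()
```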

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 17 additions & 0 deletions
@@ -1017,6 +1017,23 @@ def test_meta_device_compatibility(self) -> None:
             lengths=torch.tensor([], device=torch.device("meta")),
         )

+    def test_flatten_unflatten_with_vbe(self) -> None:
+        kjt = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        )
+
+        flat_kjt, spec = pytree.tree_flatten(kjt)
+        unflattened_kjt = pytree.tree_unflatten(flat_kjt, spec)
+
+        self.assertEqual(
+            kjt.stride_per_key_per_rank(), unflattened_kjt.stride_per_key_per_rank()
+        )
+        self.assertEqual(kjt.inverse_indices(), unflattened_kjt.inverse_indices())
+

 class TestKeyedJaggedTensorScripting(unittest.TestCase):
     def test_scriptable_forward(self) -> None:
