Skip to content

Commit c04f2da

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Delta tracker DMP integration (#3064)
Summary: Pull Request resolved: #3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR #3057 Differential Revision: D76202371
1 parent 269a7c9 commit c04f2da

File tree

7 files changed

+773
-33
lines changed

7 files changed

+773
-33
lines changed

torchrec/distributed/embedding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,13 +1514,17 @@ def compute_and_output_dist(
15141514
EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type
15151515
):
15161516
embs = lookup(features)
1517+
if self.post_lookup_tracker_hook is not None:
1518+
self.post_lookup_tracker_hook(self, features, embs)
15171519

15181520
with maybe_annotate_embedding_event(
15191521
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
15201522
):
15211523
awaitables_per_sharding.append(
15221524
odist(embs.view(-1, embedding_dim), sharding_ctx)
15231525
)
1526+
if self.post_odist_tracker_hook is not None:
1527+
self.post_odist_tracker_hook()
15241528

15251529
features_before_all2all_per_sharding.append(
15261530
# pyre-fixme[6]: For 1st argument expected `KeyedJaggedTensor` but

torchrec/distributed/embedding_types.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
import abc
1111
import copy
12+
import logging as logger
1213
from dataclasses import dataclass
1314
from enum import Enum, unique
1415
from typing import (
1516
Any,
17+
Callable,
1618
Dict,
1719
Generic,
1820
Iterable,
@@ -370,6 +372,10 @@ def __init__(
370372
self._input_dists: List[nn.Module] = []
371373
self._lookups: List[nn.Module] = []
372374
self._output_dists: List[nn.Module] = []
375+
self.post_lookup_tracker_hook: Optional[
376+
Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
377+
] = None
378+
self.post_odist_tracker_hook: Optional[Callable[..., None]] = None
373379

374380
def prefetch(
375381
self,
@@ -418,6 +424,41 @@ def train(self, mode: bool = True): # pyre-ignore[3]
418424

419425
return self
420426

427+
def register_post_lookup_tracker_hook(
428+
self,
429+
record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
430+
) -> None:
431+
"""
432+
Register a hook to be called after lookup is done. This is used for
433+
tracking the lookup results and optimizer states.
434+
435+
Args:
436+
record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
437+
438+
"""
439+
if self.post_lookup_tracker_hook is not None:
440+
logger.warning(
441+
"[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
442+
)
443+
self.post_lookup_tracker_hook = record_fn
444+
445+
def register_post_odist_tracker_hook(
446+
self,
447+
record_fn: Callable[..., None],
448+
) -> None:
449+
"""
450+
Register a hook to be called after registering odist awaitable.
451+
452+
Args:
453+
record_fn (Callable[Callable[..., None]):
454+
455+
"""
456+
if self.post_odist_tracker_hook is not None:
457+
logger.warning(
458+
"[ModelDeltaTracker] Compaction function already defined, overriding with new callable"
459+
)
460+
self.post_odist_tracker_hook = record_fn
461+
421462
@property
422463
def unsharded_module_type(self) -> Type[nn.Module]:
423464
"""

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,13 +1458,17 @@ def compute_and_output_dist(
14581458
sharding_type,
14591459
):
14601460
embs = lookup(features)
1461+
if self.post_lookup_tracker_hook is not None:
1462+
self.post_lookup_tracker_hook(self, features, embs)
14611463

14621464
with maybe_annotate_embedding_event(
14631465
EmbeddingEvent.OUTPUT_DIST,
14641466
self._module_fqn,
14651467
sharding_type,
14661468
):
14671469
awaitables.append(dist(embs, sharding_context))
1470+
if self.post_odist_tracker_hook is not None:
1471+
self.post_odist_tracker_hook()
14681472

14691473
if sharding_context:
14701474
batch_size_per_feature_pre_a2a.extend(

torchrec/distributed/model_parallel.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
from torch.nn.modules.module import _IncompatibleKeys
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torchrec.distributed.comm import get_local_size
32+
from torchrec.distributed.model_tracker.model_delta_tracker import (
33+
ModelDeltaTracker,
34+
SUPPORTED_MODULES,
35+
)
36+
from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig
3237

3338
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
3439
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -208,6 +213,7 @@ class DistributedModelParallel(nn.Module, FusedOptimizerModule):
208213
init_parameters (bool): initialize parameters for modules still on meta device.
209214
data_parallel_wrapper (Optional[DataParallelWrapper]): custom wrapper for data
210215
parallel modules.
216+
model_tracker_config (Optional[DeltaTrackerConfig]): config for model tracker.
211217
212218
Example::
213219
@@ -234,6 +240,7 @@ def __init__(
234240
init_data_parallel: bool = True,
235241
init_parameters: bool = True,
236242
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
243+
model_tracker_config: Optional[ModelTrackerConfig] = None,
237244
) -> None:
238245
super().__init__()
239246
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
@@ -242,6 +249,11 @@ def __init__(
242249

243250
self._ddp_wrapped: bool = False
244251

252+
self.has_model_tracker: bool = model_tracker_config is not None
253+
254+
# List of callables to be executed before forward
255+
self._pre_forward_callables: List[Callable[..., None]] = []
256+
245257
if env is None:
246258
pg = dist.GroupMember.WORLD
247259
assert pg is not None, "Process group is not initialized"
@@ -286,6 +298,26 @@ def __init__(
286298
if init_data_parallel:
287299
self.init_data_parallel()
288300

301+
if model_tracker_config is not None:
302+
self.model_delta_tracker: ModelDeltaTracker = self._init_delta_tracker(
303+
model_tracker_config, self._dmp_wrapped_module
304+
)
305+
tracked_modules = self.model_delta_tracker.get_tracked_modules()
306+
for module in tracked_modules.values():
307+
if isinstance(module, SUPPORTED_MODULES):
308+
# register post lookup hook
309+
module.register_post_lookup_tracker_hook(
310+
self.model_delta_tracker.record_lookup
311+
)
312+
# register auto compaction hook at odist
313+
if model_tracker_config.auto_compact:
314+
module.register_post_odist_tracker_hook(
315+
self.model_delta_tracker.trigger_compaction
316+
)
317+
self._pre_forward_callables.append(
318+
self.model_delta_tracker.increment_batch_idx
319+
)
320+
289321
@property
290322
def module(self) -> nn.Module:
291323
"""
@@ -307,6 +339,11 @@ def module(self, value: nn.Module) -> None:
307339

308340
# pyre-ignore [2, 3]
309341
def forward(self, *args, **kwargs) -> Any:
342+
# Execute pre-forward callables. Currently used to update batch index
343+
# for model tracker.
344+
for callable_fn in self._pre_forward_callables:
345+
callable_fn()
346+
310347
return self._dmp_wrapped_module(*args, **kwargs)
311348

312349
def init_data_parallel(self) -> None:
@@ -344,6 +381,19 @@ def copy(
344381
def _init_dmp(self, module: nn.Module) -> nn.Module:
345382
return self._shard_modules_impl(module)
346383

384+
def _init_delta_tracker(
385+
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
386+
) -> ModelDeltaTracker:
387+
# Init delta tracker if config is provided
388+
return ModelDeltaTracker(
389+
model=module,
390+
consumers=model_tracker_config.consumers,
391+
delete_on_read=model_tracker_config.delete_on_read,
392+
auto_compact=model_tracker_config.auto_compact,
393+
mode=model_tracker_config.tracking_mode,
394+
fqns_to_skip=model_tracker_config.fqns_to_skip,
395+
)
396+
347397
def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
348398
# pyre-ignore [6]
349399
return CombinedOptimizer(self._fused_optim_impl(module, []))
@@ -421,6 +471,25 @@ def init_parameters(module: nn.Module) -> None:
421471

422472
module.apply(init_parameters)
423473

474+
def get_model_tracker(self) -> ModelDeltaTracker:
475+
"""
476+
Returns the model tracker if it exists.
477+
"""
478+
479+
assert (
480+
self.has_model_tracker
481+
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
482+
return self.model_delta_tracker
483+
484+
def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
485+
"""
486+
Returns the delta rows for the given consumer.
487+
"""
488+
assert (
489+
self.has_model_tracker
490+
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
491+
return self.model_delta_tracker.get_delta(consumer)
492+
424493
def sparse_grad_parameter_names(
425494
self, destination: Optional[List[str]] = None, prefix: str = ""
426495
) -> List[str]:

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99
import logging as logger
1010
from collections import Counter, OrderedDict
11-
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
11+
from typing import Dict, Iterable, List, Optional, Union
1212

1313
import torch
1414

@@ -105,6 +105,37 @@ def __init__(
105105
self.feature_to_fqn[feature_name] = fqn
106106
logger.info(f"feature_to_fqn: {self.feature_to_fqn}")
107107

108+
def increment_batch_idx(self) -> None:
109+
self.curr_batch_idx += 1
110+
111+
def trigger_compaction(self) -> None:
112+
if self.curr_compact_index >= self.curr_batch_idx:
113+
# only trigger compaction once per iteration
114+
return
115+
116+
self.curr_compact_index += 1
117+
# TODO: May need to revisit the compaction logic with multiple consmers.
118+
# At present we take the max per_consumer_batch_idx to ensure we only compact
119+
# newely received lookups
120+
121+
# The trigger_compaction() function is expected to overlap with comms to hide
122+
# compaction compute overhead. Currently, we overlap compaction with odist
123+
# because ID tracking occurs during local embedding lookup, which takes place
124+
# before odist. This way, auto_compact always merges all past IDs tensors since
125+
# the last get_delta call into a single IDs tensor per FQN.
126+
#
127+
# For delete_on_read=True, get_delta() should delete up to per_consumer_batch_idx
128+
# (exclusive). So the compaction should start from per_consumer_batch_idx.
129+
#
130+
# For delete_on_read=False, get_delta() won't delete tensors, but it does advance
131+
# per_consumer_batch_idx accordingly, where all ids prior to per_consumer_batch_idx (exclusive)
132+
# should have been compacted into one tensor regardless of auto_compact=True/False.
133+
# Therefore, all future compactions should start from per_consumer_batch_idx.
134+
start_idx = max(self.per_consumer_batch_idx.values())
135+
end_idx = self.curr_batch_idx
136+
if start_idx < end_idx:
137+
self.compact(start_idx=start_idx, end_idx=end_idx)
138+
108139
def record_lookup(
109140
self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
110141
) -> None:
@@ -334,8 +365,13 @@ def compact(self, start_idx: int, end_idx: int) -> None:
334365
self.store.compact(start_idx, end_idx)
335366

336367
def _clean_fqn_fn(self, fqn: str) -> str:
337-
# strip DMP internal module FQN prefix to match state dict FQN
338-
return fqn.replace("_dmp_wrapped_module.module.", "")
368+
# strip FQN prefixes added by DMP and other TorchRec operations to match state dict FQN
369+
# handles both "_dmp_wrapped_module.module." and "module." prefixes
370+
prefixes_to_strip = ["_dmp_wrapped_module.module.", "module."]
371+
for prefix in prefixes_to_strip:
372+
if fqn.startswith(prefix):
373+
return fqn[len(prefix) :]
374+
return fqn
339375

340376
def _validate_mode(self) -> None:
341377
"To validate the mode is supported for the given module"

0 commit comments

Comments
 (0)