Delta tracker DMP integration #3064

Open · wants to merge 3 commits into main

4 changes: 4 additions & 0 deletions torchrec/distributed/embedding.py
@@ -1514,13 +1514,17 @@ def compute_and_output_dist(
EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type
):
embs = lookup(features)
if self.post_lookup_tracker_hook is not None:
self.post_lookup_tracker_hook(self, features, embs)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type
):
awaitables_per_sharding.append(
odist(embs.view(-1, embedding_dim), sharding_ctx)
)
if self.post_odist_tracker_hook is not None:
self.post_odist_tracker_hook()

features_before_all2all_per_sharding.append(
# pyre-fixme[6]: For 1st argument expected `KeyedJaggedTensor` but
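Note on the hook contract above: the post-lookup hook is invoked as post_lookup_tracker_hook(self, features, embs), while the post-odist hook is invoked with no arguments. A minimal sketch of a record function that satisfies the post-lookup signature (the function name and the logged fields are illustrative, not part of this PR):

import torch
import torch.nn as nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def example_record_lookup(
    module: nn.Module, features: KeyedJaggedTensor, embs: torch.Tensor
) -> None:
    # Illustrative only: a real tracker (e.g. ModelDeltaTracker.record_lookup)
    # would bucket the touched ids per table/FQN rather than log shapes.
    print(
        f"{type(module).__name__}: {features.values().numel()} ids -> "
        f"embeddings of shape {tuple(embs.shape)}"
    )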
41 changes: 41 additions & 0 deletions torchrec/distributed/embedding_types.py
@@ -9,10 +9,12 @@

import abc
import copy
import logging as logger
from dataclasses import dataclass
from enum import Enum, unique
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
@@ -370,6 +372,10 @@ def __init__(
self._input_dists: List[nn.Module] = []
self._lookups: List[nn.Module] = []
self._output_dists: List[nn.Module] = []
self.post_lookup_tracker_hook: Optional[
Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
] = None
self.post_odist_tracker_hook: Optional[Callable[..., None]] = None

def prefetch(
self,
@@ -418,6 +424,41 @@ def train(self, mode: bool = True): # pyre-ignore[3]

return self

def register_post_lookup_tracker_hook(
self,
record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
) -> None:
"""
Register a hook to be called after lookup is done. This is used for
tracking the lookup results and optimizer states.

Args:
record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.

"""
if self.post_lookup_tracker_hook is not None:
logger.warning(
"[ModelDeltaTracker] Custom record function already defined, overriding with new callable"
)
self.post_lookup_tracker_hook = record_fn

def register_post_odist_tracker_hook(
self,
record_fn: Callable[..., None],
) -> None:
"""
Register a hook to be called after the odist awaitable has been registered.

Args:
record_fn (Callable[..., None]): A custom record function to be called after the odist awaitable is registered.

"""
if self.post_odist_tracker_hook is not None:
logger.warning(
"[ModelDeltaTracker] Compaction function already defined, overriding with new callable"
)
self.post_odist_tracker_hook = record_fn

@property
def unsharded_module_type(self) -> Type[nn.Module]:
"""
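The two registration methods above are the integration surface a tracker uses. A minimal sketch of wiring them up, assuming sharded_module is an instance of a subclass of the base module extended here (e.g. a sharded embedding collection) and tracker is a ModelDeltaTracker:

# Record ids/embeddings right after each lookup.
sharded_module.register_post_lookup_tracker_hook(tracker.record_lookup)

# Optionally compact tracked deltas as soon as the output-dist awaitable is
# queued, so compaction can overlap with the all-to-all communication.
sharded_module.register_post_odist_tracker_hook(tracker.trigger_compaction)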
4 changes: 4 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -1458,13 +1458,17 @@ def compute_and_output_dist(
sharding_type,
):
embs = lookup(features)
if self.post_lookup_tracker_hook is not None:
self.post_lookup_tracker_hook(self, features, embs)

with maybe_annotate_embedding_event(
EmbeddingEvent.OUTPUT_DIST,
self._module_fqn,
sharding_type,
):
awaitables.append(dist(embs, sharding_context))
if self.post_odist_tracker_hook is not None:
self.post_odist_tracker_hook()

if sharding_context:
batch_size_per_feature_pre_a2a.extend(
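The same wiring applies to pooled embeddings: the post-odist hook takes no arguments and is simply a signal that the output-dist awaitable for this sharding has been queued. A sketch of a compatible callable (the name is illustrative):

def example_on_odist_queued() -> None:
    # Called once per sharding right after awaitables.append(dist(...));
    # a tracker can use this point to kick off compaction of its delta rows.
    pass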
69 changes: 69 additions & 0 deletions torchrec/distributed/model_parallel.py
@@ -29,6 +29,11 @@
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_tracker.model_delta_tracker import (
ModelDeltaTracker,
SUPPORTED_MODULES,
)
from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -208,6 +213,7 @@ class DistributedModelParallel(nn.Module, FusedOptimizerModule):
init_parameters (bool): initialize parameters for modules still on meta device.
data_parallel_wrapper (Optional[DataParallelWrapper]): custom wrapper for data
parallel modules.
model_tracker_config (Optional[ModelTrackerConfig]): config for model delta tracker.

Example::

@@ -234,6 +240,7 @@ def __init__(
init_data_parallel: bool = True,
init_parameters: bool = True,
data_parallel_wrapper: Optional[DataParallelWrapper] = None,
model_tracker_config: Optional[ModelTrackerConfig] = None,
) -> None:
super().__init__()
torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
@@ -242,6 +249,11 @@

self._ddp_wrapped: bool = False

self.has_model_tracker: bool = model_tracker_config is not None

# List of callables to be executed before forward
self._pre_forward_callables: List[Callable[..., None]] = []

if env is None:
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
@@ -286,6 +298,26 @@ def __init__(
if init_data_parallel:
self.init_data_parallel()

if model_tracker_config is not None:
self.model_delta_tracker: ModelDeltaTracker = self._init_delta_tracker(
model_tracker_config, self._dmp_wrapped_module
)
tracked_modules = self.model_delta_tracker.get_tracked_modules()
for module in tracked_modules.values():
if isinstance(module, SUPPORTED_MODULES):
# register post lookup hook
module.register_post_lookup_tracker_hook(
self.model_delta_tracker.record_lookup
)
# register auto compaction hook at odist
if model_tracker_config.auto_compact:
module.register_post_odist_tracker_hook(
self.model_delta_tracker.trigger_compaction
)
self._pre_forward_callables.append(
self.model_delta_tracker.increment_batch_idx
)

@property
def module(self) -> nn.Module:
"""
@@ -307,6 +339,11 @@ def module(self, value: nn.Module) -> None:

# pyre-ignore [2, 3]
def forward(self, *args, **kwargs) -> Any:
# Execute pre-forward callables. Currently used to update batch index
# for model tracker.
for callable_fn in self._pre_forward_callables:
callable_fn()

return self._dmp_wrapped_module(*args, **kwargs)

def init_data_parallel(self) -> None:
@@ -344,6 +381,19 @@ def copy(
def _init_dmp(self, module: nn.Module) -> nn.Module:
return self._shard_modules_impl(module)

def _init_delta_tracker(
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
) -> ModelDeltaTracker:
# Init delta tracker if config is provided
return ModelDeltaTracker(
model=module,
consumers=model_tracker_config.consumers,
delete_on_read=model_tracker_config.delete_on_read,
auto_compact=model_tracker_config.auto_compact,
mode=model_tracker_config.tracking_mode,
fqns_to_skip=model_tracker_config.fqns_to_skip,
)

def _init_optim(self, module: nn.Module) -> CombinedOptimizer:
# pyre-ignore [6]
return CombinedOptimizer(self._fused_optim_impl(module, []))
@@ -421,6 +471,25 @@ def init_parameters(module: nn.Module) -> None:

module.apply(init_parameters)

def get_model_tracker(self) -> ModelDeltaTracker:
"""
Returns the model tracker if it exists.
"""

assert (
self.has_model_tracker
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
return self.model_delta_tracker

def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
"""
Returns the delta rows for the given consumer.
"""
assert (
self.has_model_tracker
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
return self.model_delta_tracker.get_delta(consumer)

def sparse_grad_parameter_names(
self, destination: Optional[List[str]] = None, prefix: str = ""
) -> List[str]:
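Putting the pieces together, delta tracking is enabled by passing a config at DistributedModelParallel construction time and read back through the new accessors. A minimal sketch, assuming model is a shardable module, the process group is already initialized, batch is a valid model input, and ModelTrackerConfig's remaining fields have usable defaults (its full definition lives in model_tracker/types.py and is not shown in this diff):

from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.model_tracker.types import ModelTrackerConfig

# Passing a config turns tracking on; auto_compact additionally registers the
# post-odist compaction hook shown above.
dmp = DistributedModelParallel(
    module=model,
    model_tracker_config=ModelTrackerConfig(auto_compact=True),
)

out = dmp(batch)  # forward() bumps the tracker's batch index before dispatching

deltas = dmp.get_delta()           # Dict[str, DeltaRows] of rows touched so far
tracker = dmp.get_model_tracker()  # direct handle to the ModelDeltaTracker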