{Tests-WIP}[Torchrec] Add context manager to use next batch context for postprocs #2939

Status: Open · wants to merge 1 commit into base: main
40 changes: 12 additions & 28 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -58,6 +58,7 @@
     StageOut,
     StageOutputWithEvent,
     TrainPipelineContext,
+    use_context_for_postprocs,
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
@@ -792,19 +793,9 @@ def start_sparse_data_dist(
             with self._stream_context(self._data_dist_stream):
                 _wait_for_batch(batch, self._memcpy_stream)

-                original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-
-                # Temporarily set context for next iter to populate cache
-                for postproc_mod in self._pipelined_postprocs:
-                    postproc_mod.set_context(context)
-
-                _start_data_dist(self._pipelined_modules, batch, context)
-
-                # Restore context for model fwd
-                for module, context in zip(
-                    self._pipelined_postprocs, original_contexts
-                ):
-                    module.set_context(context)
+                with use_context_for_postprocs(self._pipelined_postprocs, context):
+                    _start_data_dist(self._pipelined_modules, batch, context)

     def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         """
@@ -1325,22 +1316,15 @@ def start_sparse_data_dist(
             return

-        # Temporarily set context for next iter to populate cache
-        original_contexts = [p.get_context() for p in self._pipelined_postprocs]
-        for postproc_mod in self._pipelined_postprocs:
-            postproc_mod.set_context(context)
-
-        with record_function(f"## start_sparse_data_dist {context.index} ##"):
-            with self._stream_context(self._data_dist_stream):
-                _wait_for_events(batch, context, self._data_dist_stream)
-                model_input = self.extract_model_input_from_batch(batch)
-                _start_data_dist(self._pipelined_modules, model_input, context)
-                event = torch.get_device_module(self._device).Event()
-                event.record()
-                context.events.append(event)
-
-        # Restore context for model forward
-        for module, context in zip(self._pipelined_postprocs, original_contexts):
-            module.set_context(context)
+        with use_context_for_postprocs(self._pipelined_postprocs, context):
+            with record_function(f"## start_sparse_data_dist {context.index} ##"):
+                with self._stream_context(self._data_dist_stream):
+                    _wait_for_events(batch, context, self._data_dist_stream)
+                    model_input = self.extract_model_input_from_batch(batch)
+                    _start_data_dist(self._pipelined_modules, model_input, context)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)

     def start_embedding_lookup(
         self,
25 changes: 25 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
@@ -7,6 +7,7 @@

 # pyre-strict

+import contextlib
 import copy
 import itertools
 import logging
@@ -21,6 +22,7 @@
     Callable,
     cast,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -1797,6 +1799,28 @@ def _prefetch_embeddings(
     return data_per_sharded_module


+@contextlib.contextmanager
+def use_context_for_postprocs(
+    pipelined_postprocs: List[PipelinedPostproc],
+    next_batch_context: TrainPipelineContext,
+) -> Generator[None, None, None]:
+    """
+    Temporarily set pipelined postproc context for next iter to populate cache.
+    """
+    # Save original context for model fwd
+    original_contexts = [p.get_context() for p in pipelined_postprocs]
+
+    # Temporarily set context for next iter to populate cache
+    for postproc_mod in pipelined_postprocs:
+        postproc_mod.set_context(next_batch_context)
+
+    yield
+
+    # Restore context for model fwd
+    for module, context in zip(pipelined_postprocs, original_contexts):
+        module.set_context(context)
+
+
 class SparseDataDistUtil(Generic[In]):
     """
     Helper class exposing methods for sparse data dist and prefetch pipelining.
@@ -1808,6 +1832,7 @@ class SparseDataDistUtil(Generic[In]):
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
         prefetch_stream (Optional[torch.cuda.Stream]): Stream on which model prefetch runs
             Defaults to `None`. This needs to be passed in to enable prefetch pipelining.
+        pipeline_postproc (bool): whether to pipeline postproc modules. Defaults to `False`.

     Example::
         sdd = SparseDataDistUtil(
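
Reference sketch (not part of the diff): since the PR title marks tests as WIP, below is a minimal, self-contained Python sketch of the swap-and-restore behavior that `use_context_for_postprocs` provides. The `_FakeContext` and `_FakePostproc` classes are hypothetical stand-ins for TorchRec's `TrainPipelineContext` and `PipelinedPostproc`, mimicking only the `get_context`/`set_context` surface used here; the context manager body mirrors the one added in utils.py above, so the snippet can be run on its own to sanity-check the pattern.

import contextlib
from typing import Generator, List


class _FakeContext:
    """Hypothetical stand-in for TrainPipelineContext (illustration only)."""

    def __init__(self, index: int) -> None:
        self.index = index


class _FakePostproc:
    """Hypothetical stand-in for PipelinedPostproc's get_context/set_context API."""

    def __init__(self, context: _FakeContext) -> None:
        self._context = context

    def get_context(self) -> _FakeContext:
        return self._context

    def set_context(self, context: _FakeContext) -> None:
        self._context = context


@contextlib.contextmanager
def use_context_for_postprocs(
    pipelined_postprocs: List[_FakePostproc],
    next_batch_context: _FakeContext,
) -> Generator[None, None, None]:
    """Temporarily point every postproc at the next batch's context, then restore."""
    # Save the contexts currently attached to each postproc.
    original_contexts = [p.get_context() for p in pipelined_postprocs]
    # Swap in the next batch's context so its cache can be populated.
    for postproc_mod in pipelined_postprocs:
        postproc_mod.set_context(next_batch_context)
    yield
    # Restore the original contexts for the current batch's model forward.
    for module, context in zip(pipelined_postprocs, original_contexts):
        module.set_context(context)


if __name__ == "__main__":
    postprocs = [_FakePostproc(_FakeContext(index=0)) for _ in range(2)]
    next_ctx = _FakeContext(index=1)
    with use_context_for_postprocs(postprocs, next_ctx):
        # Inside the block, every postproc sees the next batch's context.
        assert all(p.get_context() is next_ctx for p in postprocs)
    # After the block, the original contexts are restored.
    assert all(p.get_context().index == 0 for p in postprocs)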