 import contextlib
 import logging
 from collections import deque
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -319,6 +320,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         return output
 
 
+_apply_jit_context_default: ContextManager[None] = nullcontext()
+
+
 class TrainPipelineSparseDist(TrainPipeline[In, Out]):
     """
     This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
@@ -344,6 +348,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager under which
+            torch.jit.script is applied to the non-pipelined modules.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
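For context on how the new keyword is meant to be used (this is not part of the diff): `apply_jit_context` only needs to satisfy the context-manager protocol, and the shared `nullcontext()` instance defined above is the no-op default. A minimal sketch, using a made-up logging context manager and assuming the usual torchrec import path for the pipeline class:

import contextlib
import logging

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def log_jit_scripting():
    # Hypothetical helper: log around the point where the pipeline applies
    # torch.jit.script to its non-pipelined (unsharded) modules.
    logger.info("applying torch.jit.script to unsharded modules")
    try:
        yield
    finally:
        logger.info("done applying torch.jit.script")


# The pipeline would then be constructed with the new keyword, e.g.
#   TrainPipelineSparseDist(model, optimizer, device,
#                           apply_jit=True,
#                           apply_jit_context=log_jit_scripting())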
@@ -362,12 +368,14 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
         self._device = device
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -643,6 +651,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
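The hunk above only threads the keyword through to `_rewrite_model`; that helper's body is not shown in this diff. The intended pattern, sketched under the assumption that the helper simply enters the supplied context around its `torch.jit.script` call (the function below is an illustration, not the real `_rewrite_model`):

from contextlib import nullcontext
from typing import ContextManager, List

import torch


def _script_unsharded_modules(
    modules: List[torch.nn.Module],
    apply_jit: bool = False,
    apply_jit_context: ContextManager[None] = nullcontext(),
) -> List[torch.nn.Module]:
    # Illustrative stand-in for the scripting step inside _rewrite_model:
    # every non-pipelined module is scripted while the caller-supplied
    # context manager is active.
    if not apply_jit:
        return modules
    with apply_jit_context:
        return [torch.jit.script(m) for m in modules]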
@@ -993,6 +1002,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync"
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
+        apply_jit_context (ContextManager): a context manager under which
+            torch.jit.script is applied to the non-pipelined modules.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1012,6 +1023,7 @@ def __init__(
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
         strict: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model=model,
@@ -1022,6 +1034,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1305,6 +1318,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager under which
+            torch.jit.script is applied to the non-pipelined modules.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1321,6 +1336,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model=model,
@@ -1331,6 +1347,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1462,6 +1479,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager under which
+            torch.jit.script is applied to the non-pipelined modules.
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1473,8 +1492,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
@@ -1836,6 +1863,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: ContextManager[None] = _apply_jit_context_default,
     ) -> None:
         super().__init__(
             model,
@@ -1846,6 +1874,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)
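A note on the shared module-level default (an observation about contextlib, not something stated in the diff): `nullcontext()` objects have stateless no-op `__enter__`/`__exit__` methods, so a single `_apply_jit_context_default` instance can safely serve as the default for every pipeline class touched above, even if several pipelines enter it.

from contextlib import nullcontext

shared = nullcontext()

# The same nullcontext instance can be entered repeatedly, and even in a
# nested fashion; each `with` block simply does nothing on enter and exit.
with shared:
    with shared:
        pass

with shared:
    pass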