diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 4a2d9d561..a96fdd1ca 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -359,6 +359,14 @@ def __init__(self):
                 If using looped schedules, this still specifies the number of physical ranks, not the number
                 of stages. Stages per rank are inferred from split points degree, and schedule.""",
         )
+        self.parser.add_argument(
+            "--experimental.pipeline_parallel_batch_split_dim",
+            type=int,
+            default=0,
+            help="""
+                The dimension to split the batch on for pipeline parallelism. Defaults to 0 (batch dimension), but can
+                also be set to 1 (sequence dimension).""",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_split_points",
             type=string_list,
diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline.py
index e471a6700..6c242320f 100644
--- a/torchtitan/distributed/pipeline.py
+++ b/torchtitan/distributed/pipeline.py
@@ -6,6 +6,8 @@
 import os
 from typing import Callable
 
+from torch.distributed.pipelining.microbatch import TensorChunkSpec
+
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
     _PipelineScheduleRuntime,
@@ -119,17 +121,18 @@ def build_pipeline_schedule(
             f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
         )
 
-    # validate that the batch size is divisible by the number of microbatches otherwise we'll hang or error during training
-    if job_config.training.batch_size % n_microbatches != 0:
-        raise ValueError(
-            f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
-            "Update the config arguments for either batch_size or pipeline_parallel_microbatches."
-        )
+    # determine microbatch chunking specification
+    if job_config.experimental.pipeline_parallel_batch_split_dim != 0:
+        dim = job_config.experimental.pipeline_parallel_batch_split_dim
+        args_chunk_spec = (TensorChunkSpec(dim),)
+    else:
+        args_chunk_spec = None
 
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
+        args_chunk_spec=args_chunk_spec,
     )
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} "
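
Below is a minimal sketch (not part of the diff) of what the new flag controls. It assumes the schedule's input splitting behaves like torch.distributed.pipelining.microbatch.split_args_kwargs_into_chunks, the helper the pipelining schedules use to chunk inputs according to args_chunk_spec; the tensor shapes are made up for illustration.

import torch
from torch.distributed.pipelining.microbatch import (
    TensorChunkSpec,
    split_args_kwargs_into_chunks,
)

# Hypothetical input of shape (batch=4, seq=8, hidden=16).
batch = torch.randn(4, 8, 16)

# Default (pipeline_parallel_batch_split_dim = 0): microbatches are cut along dim 0.
args_dim0, _ = split_args_kwargs_into_chunks((batch,), {}, 4, args_chunk_spec=None)
print(args_dim0[0][0].shape)  # torch.Size([1, 8, 16])

# pipeline_parallel_batch_split_dim = 1: microbatches are cut along the sequence dim,
# which is what passing args_chunk_spec=(TensorChunkSpec(1),) to the schedule requests.
args_dim1, _ = split_args_kwargs_into_chunks(
    (batch,), {}, 4, args_chunk_spec=(TensorChunkSpec(1),)
)
print(args_dim1[0][0].shape)  # torch.Size([4, 2, 16])

In a torchtitan TOML config this would presumably be set as pipeline_parallel_batch_split_dim = 1 under the [experimental] section, mirroring the --experimental.pipeline_parallel_batch_split_dim CLI flag added above. The diff also removes the batch-size divisibility check, presumably because with dim-1 splitting the relevant constraint is on sequence length rather than batch size.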