Description
[RFC] Controller for SPMD+MPMD
Background
Work is currently underway to make `mark_sharding` first trace the model before it is loaded onto devices (#9341). Together with Local SPMD, this should enable us to achieve SPMD+MPMD as per its RFC. One leftover question is which controller to leverage and how to do it. This RFC proposes an approach and two examples of how SPMD+MPMD can be achieved.
API Discussion
Before getting into the specifics of the controller, it is important to quickly discuss the user experience of SPMD+MPMD; specifically, how to handle pipeline parallelism in the context of also doing gSPMD. I see two different approaches: (1) hide some of that process behind a newly created API, i.e. a new level of abstraction; (2) leverage existing pipeline parallelism tooling.
There is a temptation to create something behind a new API to simplify the process as much as possible and provide an easy user experience. However, PyTorch already has strong tooling around pipeline parallelism. These tools see external use, and they already ease the process of coordinating multiple processes that each run a different part of the pipeline.
Rather than creating a new API standard, it is likely better to approach this from a PyTorch angle: "this is a PyTorch backend; how do I do pipeline parallelism with PyTorch?" From that angle, it is better to support SPMD+MPMD in the existing pipeline parallelism APIs than to create a new one.
Approach
The general approach will be to:
- Trace model without loading it to devices
- Split model into individually executing modules
- Create processes to execute on split modules
- Have each module executed by a process that is responsible for its own gSPMD execution
From an implementation perspective, the idea is that with Local SPMD and deferred model initialization, APIs that specialize in pipeline parallelism should be able to manage their individual processes.
PiPPy
PiPPy is the pipeline parallelism library created by PyTorch, and its overall tool set may prove convenient here. With PiPPy, pipeline parallelism will usually involve:
- Initializing a model without loading it to devices
- Creating a pipe through `pipeline`
  a. At this step, a `GraphModule` is created which contains the modules for each process to execute later
- Initializing a process group (`dist.init_process_group`)
- Creating `PipelineStage`s based on the pipe
- Executing each pipeline stage
You can see a step-by-step guide in PiPPy's README, or a Llama model example here.
Either way, this lets PiPPy administer the individual processes while each process executes gSPMD for the specific modules it was created with.
Ray
Ray is a cluster controller for Python with a lot of utility for scaling large applications, including AI workloads. Ray does not have an explicit pipeline parallelism API, but pipeline parallelism can be achieved by leveraging its actors:
- Leverage PiPPy's `pipeline` to create a `GraphModule`
- Use the `GraphModule` to identify module splits
- Create Ray actors based on these graph modules
- Launch the Ray actors and wait for them to resolve
Ray will administer the different actor pods while each pod executes gSPMD for the specific modules it was created with.
A tale of two pipeline parallelism approaches
Currently, PyTorch/XLA does have a pipeline parallelism approach documented in https://github.com/pytorch/xla/tree/r2.7?tab=readme-ov-file. In the existing approach, each device is associated with a process. As the original SPMD+MPMD RFC highlighted, this approach is flawed because we are unable to apply gSPMD when using pipeline parallelism. The endeavor here to let gSPMD run pipeline-parallel through PiPPy, Ray, and other APIs might therefore cause some confusion, as it looks like a duplication of functionality.
Given that, it is worth noting that after the SPMD+MPMD effort, we should reassess our existing pipeline parallelism methodology and see whether it can be deduplicated into the more PyTorch-native approach suggested in the RFC.