Skip to content

Commit 77be55d

Browse files
authored
Unflatten traced module (#954)
## Description - Move tracer from `_export_to_torch_ir` to the official `torch.export.export` - Add unflatten utils (from torch/export/unflatten.py) to unflatten each stage module Purpose of this PR is to: - be composable with FSDP and TP, which requires structured FQNs like `a.b.c` to submodules to specify their policies. - be nice to DCP which would not like to see change of FQNs compared to original model. - retire use of `_export_to_torch_ir` per Export Team's plan. ## Test Added `test_transformer.py`. ``` class TransformerLike(torch.nn.Module): def __init__(self) -> None: super().__init__() self.layers = torch.nn.Sequential( *[ MLPModule(d_hid) for _ in range(n_layers) ] ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.layers(x) ``` We split the model into two stages. Each stages would preserve the `layers.<i>` structure as in the original model. ``` Stage 0: GraphModule( (layers): InterpreterModule( (0): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) (1): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) (2): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) (3): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) ) ) ``` ``` Stage 1: GraphModule( (layers): InterpreterModule( (4): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) (5): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) (6): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) (7): InterpreterModule( (net1): InterpreterModule() (relu): InterpreterModule() (net2): InterpreterModule() ) ) ) ``` Caveat: I temporarily disabled multi-use parameter support (aka. shared paramters or tied parameters). So some real examples may break. Will add the support back in next PR.
1 parent 7d6650a commit 77be55d

22 files changed

+1049
-252
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ exclude =
2828
./torch/include,
2929
./torch/lib,
3030
./venv,
31-
./pippy/fx,
31+
./pippy/unflatten.py,
3232
*.pyi

.github/workflows/gpu_tests.yaml

+9-7
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ jobs:
4242
- name: Activate conda env
4343
run: conda activate test
4444
- name: Install dependencies
45-
run:
45+
run: |
4646
pip install numpy
47-
pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu121
47+
pip install --pre -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html
4848
- name: Install pippy
4949
run: python setup.py install
5050
- name: Run forward-only integration test
@@ -61,10 +61,12 @@ jobs:
6161
run: pip install transformers
6262
- name: Run GPT2
6363
run: torchrun --nproc-per-node 4 examples/huggingface/pippy_gpt2.py
64-
- name: Run T5
65-
run: torchrun --nproc-per-node 2 examples/huggingface/pippy_t5.py
66-
- name: Run BERT
67-
run: torchrun --nproc-per-node 4 examples/huggingface/pippy_bert.py
64+
- name: Test CPU init + GPU run
65+
run: torchrun --nproc-per-node 4 examples/cpu_init/gpt2_cpu_init.py
66+
# - name: Run T5
67+
# run: torchrun --nproc-per-node 2 examples/huggingface/pippy_t5.py
68+
# - name: Run BERT
69+
# run: torchrun --nproc-per-node 4 examples/huggingface/pippy_bert.py
6870

6971
backward_tests_4gpu:
7072
runs-on: linux.g5.12xlarge.nvidia.gpu
@@ -87,7 +89,7 @@ jobs:
8789
- name: Install dependencies
8890
run:
8991
pip install numpy
90-
pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu121
92+
pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu121/torch_nightly.html
9193
- name: Install pippy
9294
run: python setup.py install
9395
- name: Run forward-backward test

.github/workflows/pippy_tests.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,15 @@ jobs:
8989
if [ -f requirements.txt ]; then pip install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; fi
9090
- name: Install pippy
9191
run: "python setup.py install"
92+
# IR tests
9293
- name: Test forward pipe generation
9394
run: python test/test_pipe.py
9495
- name: Test backward pipe generation
9596
run: python test/test_pipe_bwd.py
97+
- name: Test unflatten
98+
run: python test/test_unflatten.py
99+
- name: Test Transformer
100+
run: python test/test_transformer.py
96101
- name: Test pipeline schedule
97102
run: python test/test_pipeline_schedule.py
98103
# - name: Run null_coalesce_accumulate integration test

check.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ RETVAL=0
4949

5050
if (( SKIP_FORMAT == 0 )); then
5151
echo; echo "Running format check ..."
52-
ufmt diff pippy/*.py pippy/hf/*.py test/*.py
52+
ufmt diff pippy/*.py test/*.py
5353
(( RETVAL |= $? ))
5454
fi
5555

format.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
DEFAULT_TARGETS=()
88
for f in $(git ls-files | grep '\.py$'); do
99
case "$f" in
10-
'pippy/fx/'*)
10+
'pippy/unflatten.py')
1111
# ignore
1212
;;
1313

0 commit comments

Comments
 (0)