Commit c1e796b

[Flux] Enabling Flux specific CI (#1173)
## Context
1. Add two tests for the Flux-specific changes, and set up CI:
   - Integration tests: run basic, parallelism, checkpointing, and other integration tests.
   - Unit tests: currently only one test, covering flux_dataloader().
2. Create a small dataset (sample = 10) for the offline CI test environment.
3. Download the t5 and clip encoder configurations (`config.json`) and save them in torchtitan/experiments/flux/tests/assets.
4. For test purposes, use llama3's testing tokenizer (TikTokenizer) to perform end-to-end tests, without introducing multiple tokenizer files to run offline.

## Next step
- Add image-generation scripts (needs basic parallelism), and add a test that generates images.
1 parent 0354f22 commit c1e796b

33 files changed (+709 / -61 lines)

.ci/docker/common/install_conda.sh

Lines changed: 2 additions & 1 deletion
@@ -39,8 +39,9 @@ install_python() {
 install_pip_dependencies() {
   pushd /opt/conda
   # Install all Python dependencies
-  pip_install -r /opt/conda/dev-requirements.txt
+  pip_install -r /opt/conda/requirements-dev.txt
   pip_install -r /opt/conda/requirements.txt
+  pip_install -r /opt/conda/requirements-flux.txt
   popd
 }

File renamed without changes.
Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 transformers>=4.51.1
 einops
 sentencepiece
+pillow

.ci/docker/ubuntu/Dockerfile

Lines changed: 3 additions & 2 deletions
@@ -29,12 +29,13 @@ ARG MINICONDA_VERSION
 ARG PYTHON_VERSION
 ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV PATH /opt/conda/envs/py_$PYTHON_VERSION/bin:/opt/conda/bin:$PATH
-COPY dev-requirements.txt /opt/conda/
+COPY requirements-dev.txt /opt/conda/
 COPY requirements.txt /opt/conda/
+COPY requirements-flux.txt /opt/conda/
 COPY conda-env-ci.txt /opt/conda/
 COPY ./common/install_conda.sh install_conda.sh
 COPY ./common/utils.sh utils.sh
-RUN bash ./install_conda.sh && rm install_conda.sh utils.sh /opt/conda/dev-requirements.txt /opt/conda/requirements.txt /opt/conda/conda-env-ci.txt
+RUN bash ./install_conda.sh && rm install_conda.sh utils.sh /opt/conda/requirements-dev.txt /opt/conda/requirements.txt /opt/conda/requirements-flux.txt /opt/conda/conda-env-ci.txt

 USER ci-user
 CMD ["bash"]

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+name: Flux 8 GPU Integration Test
+
+on:
+  push:
+    branches: [ main ]
+    paths:
+      - 'torchtitan/experiments/flux/**'
+  pull_request:
+    paths:
+      - 'torchtitan/experiments/flux/**'
+  schedule:
+    # Runs every 6 hours
+    - cron: '0 */6 * * *'
+concurrency:
+  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+  cancel-in-progress: true
+
+defaults:
+  run:
+    shell: bash -l -eo pipefail {0}
+
+jobs:
+  build-test:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.48xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.6"
+      # This image is faster to clone than the default, but it lacks CC needed by triton
+      # (1m25s vs 2m37s).
+      docker-image: torchtitan-ubuntu-20.04-clang12
+      repository: pytorch/torchtitan
+      upload-artifact: outputs
+      script: |
+        set -eux
+
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        pip config --user set global.progress_bar off
+
+        python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
+
+        mkdir artifacts-to-be-uploaded
+        python -m torchtitan.experiments.flux.tests.integration_tests artifacts-to-be-uploaded --ngpu 8

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+name: Flux Model CPU Unit Test
+
+on:
+  push:
+    branches: [ main ]
+    paths:
+      - 'torchtitan/experiments/flux/**'
+  pull_request:
+    paths:
+      - 'torchtitan/experiments/flux/**'
+
+
+concurrency:
+  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build-test:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      docker-image: torchtitan-ubuntu-20.04-clang12
+      repository: pytorch/torchtitan
+      script: |
+        set -eux
+
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        pip config --user set global.progress_bar off
+
+        pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+        pytest torchtitan/experiments/flux/tests/unit_tests/ --cov=. --cov-report=xml --durations=20 -vv

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ wandb

 torchtitan/datasets/**/*.model
 assets/**/*.model
-torchtitan/experiments/**/assets/*
+torchtitan/experiments/flux/assets/*

 # temp files
 *.log

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin

 ### Setup
 ```
-pip install -r dev-requirements.txt
+pip install -r requirements-dev.txt
 ```

 ### Pull Requests

dev-requirements.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+.ci/docker/requirements-dev.txt

tests/integration_tests.py

Lines changed: 0 additions & 3 deletions
@@ -31,7 +31,6 @@ class OverrideDefinitions:
     test_descr: str = "default"
     test_name: str = "default"
     ngpu: int = 4
-    model_flavor: str = "debugmodel"

     def __repr__(self):
         return self.test_descr
@@ -495,7 +494,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
     # run_test supports sequence of tests.
     test_name = test_flavor.test_name
     dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
-    model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
     all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

     for idx, override_arg in enumerate(test_flavor.override_args):
@@ -508,7 +506,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
                "./scripts/estimate/run_memory_estimation.sh"
            )
        cmd += " " + dump_folder_arg
-       cmd += " " + model_flavor_arg
        if override_arg:
            cmd += " " + " ".join(override_arg)
        logger.info(
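With the `model_flavor` field and the automatic `--model.flavor` argument removed, a test that needs a non-default model flavor would presumably supply it through `override_args` instead. A hypothetical entry in the test list, for illustration only (flag values assumed, not from this commit):

```python
OverrideDefinitions(
    [
        [
            "--model.flavor debugmodel",  # flavor now passed per test, not injected globally
            "--training.steps 10",
        ],
    ],
    "hypothetical test with an explicit model flavor",
    "explicit_flavor",
)
```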

torchtitan/components/checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -561,8 +561,8 @@ def _save_last_step(self, curr_step: int) -> None:

         # For now, we will manually pop the freqs_cis buffer, as we made this permanent
         # temporarily and we don't want to include it in the exported state_dict.
-        # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py#L348
-        self.states.pop("freqs_cis")
+        # Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+        self.states.pop("freqs_cis", None)

         if self.export_dtype != torch.float32:
             self.states = {
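Switching to `dict.pop` with a default makes the cleanup a no-op when no `freqs_cis` buffer was registered (presumably the case for the FLUX model), instead of raising `KeyError`. A minimal illustration:

```python
states = {"model": object(), "optimizer": object()}  # no "freqs_cis" entry

states.pop("freqs_cis", None)   # no-op: returns None when the key is missing
# states.pop("freqs_cis")       # without a default, this would raise KeyError
```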

torchtitan/experiments/flux/README.md

Lines changed: 26 additions & 2 deletions
@@ -1,8 +1,16 @@
 # FLUX model in torchtitan
+[![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/flux_integration_test_8gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/flux_integration_test_8gpu.yaml?query=branch%3Amain)
+

 ## Overview
 This directory contains the implementation of the [FLUX](https://github.com/black-forest-labs/flux/tree/main) model in torchtitan. In torchtitan, we showcase the pre-training process of text-to-image part of the FLUX model.

+## Prerequisites
+Install the required dependencies:
+```bash
+pip install -r requirements-flux.txt
+```
+
 ## Usage
 First, download the autoencoder model from HuggingFace with your own access token:
 ```bash
@@ -22,15 +30,31 @@ If you want to train with other model config, run the following command:
 CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh
 ```

+## Running Tests
+
+### Unit Tests
+To run the unit tests for the FLUX model, use the following command:
+```bash
+pytest -s torchtitan/experiments/flux/tests/
+```
+
+### Integration Tests
+To run the integration tests for the FLUX model, use the following command:
+```bash
+python -m torchtitan.experiments.flux.tests.integration_tests <output_dir>
+```
+
+
 ## Supported Features
 - Parallelism: The model supports FSDP, HSDP for training on multiple GPUs.
 - Activation checkpointing: The model uses activation checkpointing to reduce memory usage during training.
 - Distributed checkpointing and loading.
-  - Notes on the current checkpointing implementation: TO keep the model wieghts are sharded the same way as checkpointing, we need to shard the model weights before saving the checkpoint. This is done by checking each module at the end of envaluation, and sharding the weights of the module if it is a FSDPModule.
+  - Notes on the current checkpointing implementation: To keep the model weights sharded the same way as in checkpointing, we need to shard the model weights before saving the checkpoint. This is done by checking each module at the end of evaluation, and sharding the weights of the module if it is an FSDPModule.
+- CI for FLUX model: periodically run integration tests on 8 GPUs, plus unit tests.



 ## TODO
 - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc)
 - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
-- [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
+- [ ] Add `torch.compile` support
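The checkpointing note under Supported Features describes walking the modules after evaluation and resharding any FSDP-wrapped ones before saving. A minimal sketch of that pattern, assuming FSDP2's `FSDPModule`/`reshard()` API (an illustration, not the torchtitan implementation; the function name is hypothetical):

```python
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule  # FSDP2; exact import path may vary by PyTorch version


def reshard_fsdp_modules(model: nn.Module) -> None:
    # Re-shard every FSDP-wrapped submodule so the in-memory parameters are
    # laid out the way the distributed checkpoint expects before saving.
    for module in model.modules():
        if isinstance(module, FSDPModule):
            module.reshard()
```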

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 18 additions & 14 deletions
@@ -20,8 +20,12 @@
 from torch.utils.data import IterableDataset
 from torchtitan.components.dataloader import ParallelAwareDataloader

+from torchtitan.components.tokenizer import Tokenizer
 from torchtitan.config_manager import JobConfig
-from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
+from torchtitan.experiments.flux.dataset.tokenizer import (
+    build_flux_tokenizer,
+    FluxTokenizer,
+)
 from torchtitan.tools.logging import logger


@@ -115,6 +119,13 @@ class TextToImageDatasetConfig:
         loader=lambda path: load_dataset(path, split="train", streaming=True),
         data_processor=_cc12m_wds_data_processor,
     ),
+    "cc12m-test": TextToImageDatasetConfig(
+        path="torchtitan/experiments/flux/tests/assets/cc12m_test",
+        loader=lambda path: load_dataset(
+            path, split="train", data_files={"train": "*.tar"}, streaming=True
+        ),
+        data_processor=_cc12m_wds_data_processor,
+    ),
 }


@@ -150,8 +161,8 @@ def __init__(
         self,
         dataset_name: str,
         dataset_path: Optional[str],
-        t5_tokenizer: FluxTokenizer,
-        clip_tokenizer: FluxTokenizer,
+        t5_tokenizer: Tokenizer,
+        clip_tokenizer: Tokenizer,
         job_config: Optional[JobConfig] = None,
         dp_rank: int = 0,
         dp_world_size: int = 1,
@@ -243,6 +254,7 @@ def __iter__(self):
            self._sample_idx += 1

            labels = sample_dict.pop("image")
+
            yield sample_dict, labels

    def load_state_dict(self, state_dict):
@@ -267,21 +279,13 @@ def build_flux_dataloader(
     dataset_path = job_config.training.dataset_path
     batch_size = job_config.training.batch_size

-    t5_encoder_name = job_config.encoder.t5_encoder
-    clip_encoder_name = job_config.encoder.clip_encoder
-    max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
+    t5_tokenizer, clip_tokenizer = build_flux_tokenizer(job_config)

     ds = FluxDataset(
         dataset_name=dataset_name,
         dataset_path=dataset_path,
-        t5_tokenizer=FluxTokenizer(
-            t5_encoder_name,
-            max_length=max_t5_encoding_len,
-        ),
-        clip_tokenizer=FluxTokenizer(
-            clip_encoder_name,
-            max_length=77,
-        ),  # fix max_length for CLIP
+        t5_tokenizer=t5_tokenizer,
+        clip_tokenizer=clip_tokenizer,
         job_config=job_config,
         dp_rank=dp_rank,
         dp_world_size=dp_world_size,
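For reference, the new `cc12m-test` registry entry resolves to roughly the following when its loader lambda runs; this just restates the call above rather than adding anything new, and it streams the small webdataset tar shards checked into the test assets:

```python
from datasets import load_dataset

path = "torchtitan/experiments/flux/tests/assets/cc12m_test"
# Stream the local *.tar shards as the "train" split; the test dataset is tiny
# (about 10 samples), so CI can iterate it without network access.
ds = load_dataset(path, split="train", data_files={"train": "*.tar"}, streaming=True)
first_sample = next(iter(ds))
```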

torchtitan/experiments/flux/dataset/tokenizer.py

Lines changed: 74 additions & 1 deletion
@@ -10,10 +10,54 @@

 from typing import List

+import torch
 from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.config_manager import JobConfig
+from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
 from transformers import CLIPTokenizer, T5Tokenizer


+class FluxTestTokenizer(Tokenizer):
+    """
+    Flux Tokenizer for test purpose. This is a simple wrapper around the TikTokenizer,
+    to make it has same interface as the T5 and CLIP tokenizer used for Flux.
+    """
+
+    def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwargs):
+        self.tiktokenizer = TikTokenizer(model_path, **hf_kwargs)
+        self._max_length = max_length
+        self.pad_id = 0
+
+    def _pad_and_chunk_tokens(
+        self, tokens: List[int], max_length: int, pad_token: int
+    ) -> List[int]:
+        # Pad the token sequence to max_length
+        if len(tokens) < max_length:
+            # If tokens are shorter than max_length, pad with pad_id or eos_id if pad_id is not defined
+            padding = [pad_token] * (max_length - len(tokens))
+            tokens = tokens + padding
+
+        # Chunk the token sequence to max_length
+        if len(tokens) > max_length:
+            tokens = tokens[:max_length]
+
+        return tokens
+
+    def encode(self, text: str) -> torch.Tensor:
+        """
+        Use TikTokenizer to encode the text into tokens, and then pad and chunk the tokens to max_length.
+        """
+        tokens = self.tiktokenizer.encode(text, bos=True, eos=True)
+        tokens = self._pad_and_chunk_tokens(tokens, self._max_length, self.pad_id)
+        return torch.tensor(tokens)
+
+    def decode(self, t: List[int]) -> str:
+        """
+        Decode function. This function will not be called.
+        """
+        return self.tiktokenizer.decode(t)
+
+
 class FluxTokenizer(Tokenizer):
     """
     Tokenizing and encoding/decoding text using the T5 or Clip tokenizer.
@@ -42,7 +86,7 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar
     def encode(
         self,
         s: str,
-    ) -> List[int]:
+    ) -> torch.Tensor:
         """
         Encode the prompt text into tokens.
         """
@@ -62,3 +106,32 @@ def decode(self, t: List[int]) -> str:
         Decode function. This function will not be called.
         """
         return self._tokenizer.decode(t)
+
+
+def build_flux_tokenizer(job_config: JobConfig) -> tuple[Tokenizer, Tokenizer]:
+    """
+    Build the tokenizer for Flux.
+    """
+    t5_tokenizer_path = job_config.encoder.t5_encoder
+    clip_tokenzier_path = job_config.encoder.clip_encoder
+    max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
+
+    # NOTE: This tokenizer is used for offline CI and testing only, borrowed from llama3 tokenizer
+    if job_config.training.test_mode:
+        tokenizer_class = FluxTestTokenizer
+        t5_tokenizer_path = clip_tokenzier_path = job_config.model.tokenizer_path
+    else:
+        tokenizer_class = FluxTokenizer
+
+    # T5 tokenzier will pad the token sequence to max_t5_encoding_len,
+    # and CLIP tokenizer will pad the token sequence to 77 (fixed number).
+    t5_tokenizer = tokenizer_class(
+        t5_tokenizer_path,
+        max_length=max_t5_encoding_len,
+    )
+    clip_tokenizer = tokenizer_class(
+        clip_tokenzier_path,
+        max_length=77,
+    )
+
+    return t5_tokenizer, clip_tokenizer
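A hypothetical usage sketch of the test-mode path (the tokenizer file path below is assumed for illustration; in CI it comes from `job_config.model.tokenizer_path` and points at llama3's testing TikToken model):

```python
from torchtitan.experiments.flux.dataset.tokenizer import FluxTestTokenizer

# Sketch only: FluxTestTokenizer wraps TikTokenizer and always returns a
# fixed-length tensor, matching the interface of the real T5/CLIP tokenizers.
tokenizer = FluxTestTokenizer(
    "tests/assets/test_tiktoken.model",  # assumed path to the llama3 test tokenizer
    max_length=77,
)
tokens = tokenizer.encode("a photo of a cat")
assert tokens.shape == (77,)  # padded or truncated to max_length
```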
