[Flux] Add batched inference #1227

Open · wants to merge 6 commits into base: main

5 changes: 5 additions & 0 deletions torchtitan/experiments/flux/README.md
@@ -34,6 +34,11 @@ If you want to train with other model config, run the following command:
CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh
```

To perform inference with the model, run the following command:
```bash
JOB_FOLDER=outputs torchtitan/experiments/flux/run_inference.sh
```

## Running Tests

### Unit Tests
13 changes: 13 additions & 0 deletions torchtitan/experiments/flux/job_config.py
@@ -47,6 +47,18 @@ class Eval:
"""Directory to save image generated/sampled from the model"""


@dataclass
class Inference:
"""Inference configuration"""

save_path: str = "inference_results"
"""Path to save the inference results"""
prompts_path: str = "prompts.txt"
"""Path to file with newline separated prompts to generate images for"""
batch_size: int = 16
"""Batch size for inference"""


@dataclass
class JobConfig:
"""
@@ -56,3 +68,4 @@ class JobConfig:
training: Training = field(default_factory=Training)
encoder: Encoder = field(default_factory=Encoder)
eval: Eval = field(default_factory=Eval)
inference: Inference = field(default_factory=Inference)
32 changes: 32 additions & 0 deletions torchtitan/experiments/flux/prompts.txt
@@ -0,0 +1,32 @@
A serene mountain landscape at sunset with a crystal clear lake reflecting the golden sky
A futuristic cityscape with flying cars and neon lights illuminating the night sky
A cozy cafe interior with steam rising from coffee cups and warm lighting
A magical forest with glowing mushrooms and fireflies dancing between ancient trees
A peaceful beach scene with turquoise waves and palm trees swaying in the breeze
A steampunk-inspired mechanical dragon soaring through clouds
A mystical library with floating books and magical artifacts
A Japanese garden in spring with cherry blossoms falling gently
A space station orbiting a colorful nebula
A medieval castle on a hilltop during a dramatic thunderstorm
An underwater scene with bioluminescent creatures and coral reefs
A desert oasis with a majestic palace and palm trees
A cyberpunk street market with holographic signs and diverse crowds
A cozy winter cabin surrounded by snow-covered pine trees
A fantasy tavern filled with unique characters and magical atmosphere
A tropical rainforest with exotic birds and waterfalls
A steampunk airship navigating through storm clouds
A peaceful zen garden with a traditional Japanese tea house
A magical potion shop with bubbling cauldrons and mysterious ingredients
A futuristic space colony on Mars with domed habitats
A mystical temple hidden in the clouds
A vintage train station with steam locomotives and period architecture
A magical bakery with floating pastries and enchanted ingredients
A peaceful countryside scene with rolling hills and a rustic farmhouse
An underwater city with advanced technology and marine life
A fantasy marketplace with magical creatures and exotic goods
A peaceful meditation garden with lotus flowers and koi ponds
A steampunk laboratory with intricate machinery and glowing elements
A magical treehouse village connected by rope bridges
A peaceful mountain monastery with prayer flags in the wind
A futuristic greenhouse with exotic plants and advanced technology
A mystical crystal cave with glowing formations and underground streams
37 changes: 37 additions & 0 deletions torchtitan/experiments/flux/run_inference.sh
Contributor:

To group files a bit more logically, can we put run_inference.sh, prompts.txt, and infer.py under the flux/inference folder? We can leave sampling.py outside, as it's also used by the evaluation in train.py.

@@ -0,0 +1,37 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_inference.sh

if [ -z "${JOB_FOLDER}" ]; then
Contributor:

Can we default to job.dump_folder's default?

Contributor (author):

I would argue that making it explicitly required makes using the script clearer and less error-prone, at the cost of some user friendliness. But I understand that point as well; I'll make the change.

echo "Error: JOB_FOLDER environment variable with path to model to load must be set"
exit 1
fi

NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
OUTPUT_DIR=${OUTPUT_DIR:-"inference_results"}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"}
PROMPTS_FILE=${PROMPTS_FILE:-"torchtitan/experiments/flux/prompts.txt"}
overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi


PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-m torchtitan.experiments.flux.scripts.infer --job.config_file=${CONFIG_FILE} \
--inference.save_path=${OUTPUT_DIR} --job.dump_folder=${JOB_FOLDER} \
--inference.prompts_path=${PROMPTS_FILE} --checkpoint.enable_checkpoint \
--checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer $overrides
67 changes: 49 additions & 18 deletions torchtitan/experiments/flux/sampling.py
@@ -71,12 +71,38 @@ def get_schedule(
# ----------------------------------------


def generate_empty_batch(
num_images: int,
device: torch.device,
dtype: torch.dtype,
clip_tokenizer: Tokenizer,
t5_tokenizer: Tokenizer,
clip_encoder: FluxEmbedder,
t5_encoder: FluxEmbedder,
):
empty_clip_tokens = clip_tokenizer.encode("")
empty_t5_tokens = t5_tokenizer.encode("")
empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1)
empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1)
return preprocess_data(
device=device,
dtype=dtype,
autoencoder=None,
clip_encoder=clip_encoder,
t5_encoder=t5_encoder,
batch={
"clip_tokens": empty_clip_tokens,
"t5_tokens": empty_t5_tokens,
},
)


def generate_image(
device: torch.device,
dtype: torch.dtype,
job_config: JobConfig,
model: FluxModel,
prompt: str,
prompt: str | list[str],
autoencoder: AutoEncoder,
t5_tokenizer: Tokenizer,
clip_tokenizer: Tokenizer,
@@ -89,19 +115,25 @@ def generate_image(
Since we will always use the local random seed on this rank, we don't need to pass in the seed again.
"""

if isinstance(prompt, str):
prompt = [prompt]

# allow for packing and conversion to latent space. Use the same resolution as training time.
img_height = 16 * (job_config.training.img_size // 16)
img_width = 16 * (job_config.training.img_size // 16)

enable_classifer_free_guidance = job_config.eval.enable_classifer_free_guidance

# Tokenize the prompt. Unsqueeze to add a batch dimension.
clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0)
t5_tokens = t5_tokenizer.encode(prompt).unsqueeze(0)
clip_tokens = clip_tokenizer.encode(prompt)
t5_tokens = t5_tokenizer.encode(prompt)
if len(prompt) == 1:
clip_tokens = clip_tokens.unsqueeze(0)
t5_tokens = t5_tokens.unsqueeze(0)

batch = preprocess_data(
device=device,
dtype=torch.bfloat16,
dtype=dtype,
autoencoder=None,
clip_encoder=clip_encoder,
t5_encoder=t5_encoder,
@@ -112,18 +144,14 @@ )
)

if enable_classifer_free_guidance:
empty_clip_tokens = clip_tokenizer.encode("").unsqueeze(0)
empty_t5_tokens = t5_tokenizer.encode("").unsqueeze(0)
empty_batch = preprocess_data(
empty_batch = generate_empty_batch(
num_images=len(prompt),
device=device,
dtype=torch.bfloat16,
autoencoder=None,
dtype=dtype,
clip_tokenizer=clip_tokenizer,
t5_tokenizer=t5_tokenizer,
clip_encoder=clip_encoder,
t5_encoder=t5_encoder,
batch={
"clip_tokens": empty_clip_tokens,
"t5_tokens": empty_t5_tokens,
},
)

img = denoise(
@@ -145,7 +173,7 @@
classifier_free_guidance_scale=job_config.eval.classifier_free_guidance_scale,
)

img = autoencoder.decode(img)
img = autoencoder.decode(img.to(dtype))
return img


@@ -177,9 +205,9 @@ def denoise(
# create positional encodings
POSITION_DIM = 3
latent_pos_enc = create_position_encoding_for_latents(
bsz, latent_height, latent_width, POSITION_DIM
1, latent_height, latent_width, POSITION_DIM
Contributor:

QQ: Why do we change the bsz to 1 here and later, given that we are taking bsz prompts as input?

Contributor (author):

There are two parts to this. The reason we can do it is that the tensors where I use 1 on the first dimension are identical for all samples, so we only need them repeated across the batch. Due to torch broadcasting, whenever such a tensor participates in an operation with another tensor, that dimension is expanded to whatever the other tensor requires (basically, torch automatically makes it match the batch size).

The reason we want to do this is twofold:

  1. It saves some memory: instead of carrying around all these repeated tensors, we let torch broadcast during the operations.
  2. If we don't, then whenever we do classifier-free guidance we would have to manually double the size of these tensors. This way we don't have to worry about it, as the broadcast is handled correctly.
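
For illustration, a minimal, self-contained sketch of the broadcasting being relied on here (shapes are made up, and the `+` is just a stand-in for whatever op the model performs with these tensors, not how Flux consumes them):

```python
import torch

bsz, seq_len, dim = 4, 1024, 3  # illustrative sizes only

# Identical for every sample, so it is kept with a batch dimension of 1.
latent_pos_enc = torch.randn(1, seq_len, dim)

# A real per-sample tensor with a full batch dimension.
latents = torch.randn(bsz, seq_len, dim)

# The size-1 batch dim broadcasts; no repeated copies are materialized.
print((latents + latent_pos_enc).shape)      # torch.Size([4, 1024, 3])

# Classifier-free guidance doubles the batch; the same size-1 tensor still broadcasts.
cfg_latents = torch.cat([latents, latents], dim=0)
print((cfg_latents + latent_pos_enc).shape)  # torch.Size([8, 1024, 3])
```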

Contributor:

Ah, the reasoning makes sense to me.

It feels to me that, if the result of latent_pos_enc is always identical for all samples in a batch, we should probably just remove bsz as an input arg and not worry about the batch at all until it is broadcast, instead of hardcoding bsz=1 in multiple places.

position_encoding = position_encoding.repeat(bsz, 1, 1)

).to(latents)
text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
text_pos_enc = torch.zeros(1, t5_encodings.shape[1], POSITION_DIM).to(latents)

if enable_classifer_free_guidance:
latents = torch.cat([latents, latents], dim=0)
@@ -191,7 +219,7 @@

# this is ignored for schnell
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
t_vec = torch.full((1,), t_curr, dtype=dtype, device=device)
pred = model(
img=latents,
img_ids=latent_pos_enc,
@@ -203,9 +231,12 @@
if enable_classifer_free_guidance:
pred_u, pred_c = pred.chunk(2)
pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)

pred = pred.repeat(2, 1, 1)
Contributor:

And QQ: Why do we need to repeat the first dimension of pred?

Contributor (author):

My logic is as follows:

Previously, since we were dealing with just one input, we didn't have to do this: pred would end up with a bsz of 1. In the case of classifier_free_guidance, latents would have a bsz of 2; since pred had a bsz of 1, torch would broadcast it and it would work.

However, we now support batch sizes > 1, so pred ends up with a batch size equal to half the bsz of latents, which in general is not 1. Thus we can no longer rely on broadcasting and have to do this repeat manually.
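
For illustration, a shape-only sketch of why the repeat becomes necessary once the batch size exceeds 1 (sizes and the update constants are illustrative, not the actual Flux dimensions):

```python
import torch

B, seq, dim = 4, 64, 8  # illustrative sizes

latents = torch.randn(2 * B, seq, dim)   # [unconditional | conditional] stacked for CFG
pred = torch.randn(2 * B, seq, dim)      # model output on the doubled batch

pred_u, pred_c = pred.chunk(2)           # each (B, seq, dim)
guided = pred_u + 3.5 * (pred_c - pred_u)

# With B == 1 the (1, seq, dim) prediction broadcasts against (2, seq, dim) latents.
# With B > 1 the shapes (B, ...) vs (2B, ...) no longer broadcast, so repeat explicitly:
guided = guided.repeat(2, 1, 1)          # (2B, seq, dim)
latents = latents + 0.1 * guided         # the denoising update now lines up
```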

latents = latents + (t_prev - t_curr) * pred

if enable_classifer_free_guidance:
latents = latents.chunk(2)[1]
Contributor:

Yeah this code, together with the pred = pred.repeat(2, 1, 1) above, looks very obscure.
If they are necessary, please add adequate comments.


# convert sequences of patches into img-like latents
latents = unpack_latents(latents, latent_height, latent_width)

147 changes: 147 additions & 0 deletions torchtitan/experiments/flux/scripts/infer.py
Contributor:

Functionality-wise, this seems similar to torchtitan/experiments/flux/tests/test_generate_image.py, but with a parallelized model.
I think we can make test_generate_image a unit test, if not remove it after this multi-GPU generation lands. @wwwjn

@@ -0,0 +1,147 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
from pathlib import Path

import torch
from einops import rearrange
from PIL import ExifTags, Image
from torch.distributed.elastic.multiprocessing.errors import record

from torchtitan.config_manager import ConfigManager
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.experiments.flux.sampling import generate_image
from torchtitan.experiments.flux.train import FluxTrainer
from torchtitan.tools.logging import init_logger, logger


def torch_to_pil(x: torch.Tensor) -> Image.Image:
x = x.clamp(-1, 1)
x = rearrange(x, "c h w -> h w c")
return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())


@record
def inference(
Contributor:

It's not obvious to me why this is worth a standalone function -- can we just call generate_image in the main script?

Contributor (author):

The only difference here is handling the batching. This could be handled by a method in the trainer, but I previously had it that way and refactored it out after the discussions in #1205.

I think both are valid, and it's a matter of a design decision for torchtitan.

Contributor:

I see. Sounds OK to me.

prompts: list[str],
trainer: FluxTrainer,
t5_tokenizer: FluxTokenizer,
clip_tokenizer: FluxTokenizer,
bs: int = 1,
):
"""
Run inference on the Flux model.
"""
results = []
with torch.no_grad():
for i in range(0, len(prompts), bs):
images = generate_image(
device=trainer.device,
dtype=trainer._dtype,
job_config=trainer.job_config,
model=trainer.model_parts[0],
prompt=prompts[i : i + bs],
autoencoder=trainer.autoencoder,
t5_tokenizer=t5_tokenizer,
clip_tokenizer=clip_tokenizer,
t5_encoder=trainer.t5_encoder,
clip_encoder=trainer.clip_encoder,
)
results.append(images.detach())
results = torch.cat(results, dim=0)
return results


if __name__ == "__main__":
Contributor:

It's better to put the following logic in a separate function (like main()). This will allow easier logic reuse (e.g., for unit tests).
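
For illustration, a minimal sketch of the suggested structure; `run_inference` is a hypothetical helper standing in for the current `__main__` body, while init_logger, ConfigManager, and FluxTrainer are the imports already used by this script:

```python
def main() -> None:
    init_logger()
    config = ConfigManager().parse_args()
    try:
        trainer = FluxTrainer(config)
        run_inference(config, trainer)  # hypothetical: the existing __main__ body moved here
    finally:
        # Clean up the process group even if inference raises.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```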

init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer = FluxTrainer(config)
world_size = int(os.environ["WORLD_SIZE"])
global_id = int(os.environ["RANK"])
Contributor:

Can we use rank to match the convention?

original_prompts = open(config.inference.prompts_path).readlines()
total_prompts = len(original_prompts)

# Each process processes its shard
prompts = original_prompts[global_id::world_size]

trainer.checkpointer.load(step=config.checkpoint.load_step)
t5_tokenizer = FluxTokenizer(
config.encoder.t5_encoder,
max_length=config.encoder.max_t5_encoding_len,
)
clip_tokenizer = FluxTokenizer(config.encoder.clip_encoder, max_length=77)

if global_id == 0:
Contributor:

This is not required; the logging output should be controlled by the TorchRun configuration, which TorchTitan run scripts default to rank 0 only.

logger.info("Starting inference...")

if prompts:
images = inference(
prompts,
trainer,
t5_tokenizer,
clip_tokenizer,
bs=config.inference.batch_size,
)
# pad the outputs to make sure all ranks have the same number of images for the gather step
images = torch.cat(
[
images,
torch.zeros(
math.ceil(total_prompts / world_size) - images.shape[0],
3,
256,
256,
Comment on lines +98 to +100
Contributor:

Can we make these magic numbers named constants, for readability?
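
For illustration, one way to do this (the constant names are hypothetical; the values are the 3, 256, 256 used for the padding tensors and the gather buffer below):

```python
# Hypothetical module-level constants replacing the magic numbers.
IMG_CHANNELS = 3
IMG_HEIGHT = 256
IMG_WIDTH = 256
```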

device=trainer.device,
),
]
)
else:
# if there are not enough prompts for all ranks, pad with empty tensors
images = torch.zeros(
math.ceil(total_prompts / world_size), 3, 256, 256, device=trainer.device
)

# Create a list of tensors to gather results
gathered_images = [
torch.zeros_like(images, device=trainer.device) for _ in range(world_size)
]

# Gather images from all processes
torch.distributed.all_gather(gathered_images, images)
Contributor:

This should be gather(), not all_gather(), as you are not using the results on the other ranks. Though I don't know if there will be any performance gains, gather() produces less total network traffic.
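
For illustration, a minimal sketch of the suggestion (assuming the active backend supports gather, which recent NCCL-backed PyTorch releases do): only the destination rank allocates the output list.

```python
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# Only rank 0 needs buffers; other ranks pass gather_list=None.
gathered_images = (
    [torch.zeros_like(images) for _ in range(world_size)] if rank == 0 else None
)
torch.distributed.gather(images, gather_list=gathered_images, dst=0)
# gathered_images is populated only on rank 0; other ranks keep None.
```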


# re-order the images to match the original ordering of prompts
if global_id == 0:
Contributor:

Another good motivation for moving the logic into main() -- you can do an early return here to remove one level of indentation.

all_images = torch.zeros(
size=[total_prompts, 3, 256, 256],
dtype=torch.float32,
device=trainer.device,
)
for in_rank_index in range(math.ceil(total_prompts / world_size)):
for rank_index in range(world_size):
global_idx = rank_index + in_rank_index * world_size
if global_idx >= total_prompts:
break
all_images[global_idx] = gathered_images[rank_index][in_rank_index]
logger.info("Inference done")

pil_images = [torch_to_pil(img) for img in all_images]
if config.inference.save_path:
path = Path(config.job.dump_folder, config.inference.save_path)
Contributor:

Can we reuse the save_image() function in sampling.py here?

path.mkdir(parents=True, exist_ok=True)
for i, img in enumerate(pil_images):
exif_data = Image.Exif()
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
exif_data[ExifTags.Base.Model] = "Schnell"
exif_data[ExifTags.Base.ImageDescription] = original_prompts[i]
img.save(
path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0
Contributor:

If we are eventually saving individual image files, why do we even perform the gather / all-gather? We could save different images to the same folder from different ranks, just with unique names, i.e. rank_{i} in the .png name.

Contributor (author):

That's true. The tricky part is placing all the tensors back in the same order so we can match them up with the prompts. Because of the padding involved it's not super straightforward, but I'm sure there's a way to do it while having each rank write its own images. It might just take some more thought.

Contributor:

An easier way to bypass the padding is to require that the prompts file have a length divisible by the DP degree, or the world size. Users of this script can manually add empty rows if needed.
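
For illustration, a hedged sketch of that alternative: pad the prompt list with empty strings up to a multiple of world_size, so every rank produces the same number of images and no placeholder tensors are needed before the gather (the helper name is hypothetical).

```python
def pad_prompts_to_world_size(prompts: list[str], world_size: int) -> list[str]:
    # Append empty prompts until the list length divides evenly across ranks.
    remainder = len(prompts) % world_size
    if remainder:
        prompts = prompts + [""] * (world_size - remainder)
    return prompts
```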

)
torch.distributed.destroy_process_group()
Contributor:

We should do something like:

try:
    main()
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()