[Flux] Add batched inference #1227
base: main
Changes from all commits: 929fd66, afabe1a, ee9adbf, 56d5ece, e6ac626, eccbbe3
torchtitan/experiments/flux/prompts.txt
@@ -0,0 +1,32 @@
A serene mountain landscape at sunset with a crystal clear lake reflecting the golden sky
A futuristic cityscape with flying cars and neon lights illuminating the night sky
A cozy cafe interior with steam rising from coffee cups and warm lighting
A magical forest with glowing mushrooms and fireflies dancing between ancient trees
A peaceful beach scene with turquoise waves and palm trees swaying in the breeze
A steampunk-inspired mechanical dragon soaring through clouds
A mystical library with floating books and magical artifacts
A Japanese garden in spring with cherry blossoms falling gently
A space station orbiting a colorful nebula
A medieval castle on a hilltop during a dramatic thunderstorm
An underwater scene with bioluminescent creatures and coral reefs
A desert oasis with a majestic palace and palm trees
A cyberpunk street market with holographic signs and diverse crowds
A cozy winter cabin surrounded by snow-covered pine trees
A fantasy tavern filled with unique characters and magical atmosphere
A tropical rainforest with exotic birds and waterfalls
A steampunk airship navigating through storm clouds
A peaceful zen garden with a traditional Japanese tea house
A magical potion shop with bubbling cauldrons and mysterious ingredients
A futuristic space colony on Mars with domed habitats
A mystical temple hidden in the clouds
A vintage train station with steam locomotives and period architecture
A magical bakery with floating pastries and enchanted ingredients
A peaceful countryside scene with rolling hills and a rustic farmhouse
An underwater city with advanced technology and marine life
A fantasy marketplace with magical creatures and exotic goods
A peaceful meditation garden with lotus flowers and koi ponds
A steampunk laboratory with intricate machinery and glowing elements
A magical treehouse village connected by rope bridges
A peaceful mountain monastery with prayer flags in the wind
A futuristic greenhouse with exotic plants and advanced technology
A mystical crystal cave with glowing formations and underground streams
torchtitan/experiments/flux/run_inference.sh
@@ -0,0 +1,37 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_inference.sh

if [ -z "${JOB_FOLDER}" ]; then
Review comment: can we default to
Reply: I would argue making it explicitly required makes using the script clearer and less error-prone, at the cost of some user-friendliness. But I understand that point as well. I'll make the change.
    echo "Error: JOB_FOLDER environment variable with path to model to load must be set"
    exit 1
fi

NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
OUTPUT_DIR=${OUTPUT_DIR:-"inference_results"}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"}
PROMPTS_FILE=${PROMPTS_FILE:-"torchtitan/experiments/flux/prompts.txt"}
overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
    --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
    -m torchtitan.experiments.flux.scripts.infer --job.config_file=${CONFIG_FILE} \
    --inference.save_path=${OUTPUT_DIR} --job.dump_folder=${JOB_FOLDER} \
    --inference.prompts_path=${PROMPTS_FILE} --checkpoint.enable_checkpoint \
    --checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer $overrides
torchtitan/experiments/flux/sampling.py
@@ -71,12 +71,38 @@ def get_schedule(
     # ----------------------------------------


+def generate_empty_batch(
+    num_images: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    clip_tokenizer: Tokenizer,
+    t5_tokenizer: Tokenizer,
+    clip_encoder: FluxEmbedder,
+    t5_encoder: FluxEmbedder,
+):
+    empty_clip_tokens = clip_tokenizer.encode("")
+    empty_t5_tokens = t5_tokenizer.encode("")
+    empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1)
+    empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1)
+    return preprocess_data(
+        device=device,
+        dtype=dtype,
+        autoencoder=None,
+        clip_encoder=clip_encoder,
+        t5_encoder=t5_encoder,
+        batch={
+            "clip_tokens": empty_clip_tokens,
+            "t5_tokens": empty_t5_tokens,
+        },
+    )
+
+
 def generate_image(
     device: torch.device,
     dtype: torch.dtype,
     job_config: JobConfig,
     model: FluxModel,
-    prompt: str,
+    prompt: str | list[str],
     autoencoder: AutoEncoder,
     t5_tokenizer: Tokenizer,
     clip_tokenizer: Tokenizer,
@@ -89,19 +115,25 @@ def generate_image(
     Since we will always use the local random seed on this rank, we don't need to pass in the seed again.
     """

+    if isinstance(prompt, str):
+        prompt = [prompt]
+
     # allow for packing and conversion to latent space. Use the same resolution as training time.
     img_height = 16 * (job_config.training.img_size // 16)
     img_width = 16 * (job_config.training.img_size // 16)

     enable_classifer_free_guidance = job_config.eval.enable_classifer_free_guidance

     # Tokenize the prompt. Unsqueeze to add a batch dimension.
-    clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0)
-    t5_tokens = t5_tokenizer.encode(prompt).unsqueeze(0)
+    clip_tokens = clip_tokenizer.encode(prompt)
+    t5_tokens = t5_tokenizer.encode(prompt)
+    if len(prompt) == 1:
+        clip_tokens = clip_tokens.unsqueeze(0)
+        t5_tokens = t5_tokens.unsqueeze(0)

     batch = preprocess_data(
         device=device,
-        dtype=torch.bfloat16,
+        dtype=dtype,
         autoencoder=None,
         clip_encoder=clip_encoder,
         t5_encoder=t5_encoder,
@@ -112,18 +144,14 @@ def generate_image(
     )

     if enable_classifer_free_guidance:
-        empty_clip_tokens = clip_tokenizer.encode("").unsqueeze(0)
-        empty_t5_tokens = t5_tokenizer.encode("").unsqueeze(0)
-        empty_batch = preprocess_data(
+        empty_batch = generate_empty_batch(
+            num_images=len(prompt),
             device=device,
-            dtype=torch.bfloat16,
-            autoencoder=None,
+            dtype=dtype,
+            clip_tokenizer=clip_tokenizer,
+            t5_tokenizer=t5_tokenizer,
             clip_encoder=clip_encoder,
             t5_encoder=t5_encoder,
-            batch={
-                "clip_tokens": empty_clip_tokens,
-                "t5_tokens": empty_t5_tokens,
-            },
         )

     img = denoise(
@@ -145,7 +173,7 @@ def generate_image(
         classifier_free_guidance_scale=job_config.eval.classifier_free_guidance_scale,
     )

-    img = autoencoder.decode(img)
+    img = autoencoder.decode(img.to(dtype))
     return img
@@ -177,9 +205,9 @@ def denoise(
     # create positional encodings
     POSITION_DIM = 3
     latent_pos_enc = create_position_encoding_for_latents(
-        bsz, latent_height, latent_width, POSITION_DIM
+        1, latent_height, latent_width, POSITION_DIM
Review comment: QQ: Why we change the
Reply: So there are two parts to this. The reason we can do this is that, for these particular tensors where I'm using 1 on the first dimension, they are the same for all samples, so we just want to repeat them across the batch. Due to torch broadcasting, whenever such a tensor is used in an operation with another tensor, this dimension will be expanded to match whatever is necessary from the other tensor (basically, torch will automatically make it whatever the batch size is). The reason we want to do this is twofold.
Review comment: Ah, the reasoning makes sense to me. It feels to me that, if the result of
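To make the broadcasting point concrete, here is a minimal standalone sketch (illustrative only, not from this PR) showing a size-1 leading dimension expanding to the other operand's batch size:

```python
import torch

# Positional encodings are identical for every sample, so a leading dim of 1
# suffices; broadcasting expands it to the batch size when it is used.
pos_enc = torch.zeros(1, 64, 3)   # (1, seq, dim), shared across the batch
latents = torch.randn(8, 64, 3)   # (bsz, seq, dim)
out = latents + pos_enc           # pos_enc is broadcast to (8, 64, 3)
assert out.shape == (8, 64, 3)
```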
     ).to(latents)
-    text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
+    text_pos_enc = torch.zeros(1, t5_encodings.shape[1], POSITION_DIM).to(latents)

     if enable_classifer_free_guidance:
         latents = torch.cat([latents, latents], dim=0)
@@ -191,7 +219,7 @@ def denoise(

     # this is ignored for schnell
     for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
-        t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
+        t_vec = torch.full((1,), t_curr, dtype=dtype, device=device)
         pred = model(
             img=latents,
             img_ids=latent_pos_enc,
@@ -203,9 +231,12 @@ def denoise( | |||
if enable_classifer_free_guidance: | ||||
pred_u, pred_c = pred.chunk(2) | ||||
pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u) | ||||
|
||||
pred = pred.repeat(2, 1, 1) | ||||
Review comment: And QQ: Why we need to repeat the first dimension of
Reply: My logic is as follows: previously, since we were dealing with just one input, we didn't have to do this, as pred would end up with a bsz of 1 in the case of classifier-free guidance. However, now we can support batch sizes > 1, so pred will end up with a batch size equal to 1/2 the bsz of latents, which in general will not be 1. Thus, we cannot benefit from broadcasting anymore, and have to do this repeat manually ourselves.
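A quick shape walk-through of that guidance step (illustrative tensors only, assuming the [uncond | cond] stacking described above):

```python
import torch

bsz, seq, dim, scale = 4, 64, 16, 3.5
latents = torch.randn(2 * bsz, seq, dim)   # [uncond | cond] stacked for CFG
pred = torch.randn(2 * bsz, seq, dim)      # model output on the doubled batch

pred_u, pred_c = pred.chunk(2)             # each (bsz, seq, dim)
pred = pred_u + scale * (pred_c - pred_u)  # guided prediction, (bsz, seq, dim)
pred = pred.repeat(2, 1, 1)                # back to (2 * bsz, ...), matching latents
assert (latents + 0.1 * pred).shape == (2 * bsz, seq, dim)
```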
         latents = latents + (t_prev - t_curr) * pred

+    if enable_classifer_free_guidance:
+        latents = latents.chunk(2)[1]

Review comment: Yeah this code, together with the

     # convert sequences of patches into img-like latents
     latents = unpack_latents(latents, latent_height, latent_width)

Review comment: Functionality-wise, this seems similar to
torchtitan/experiments/flux/scripts/infer.py
@@ -0,0 +1,147 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
from pathlib import Path

import torch
from einops import rearrange
from PIL import ExifTags, Image
from torch.distributed.elastic.multiprocessing.errors import record

from torchtitan.config_manager import ConfigManager
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.experiments.flux.sampling import generate_image
from torchtitan.experiments.flux.train import FluxTrainer
from torchtitan.tools.logging import init_logger, logger


def torch_to_pil(x: torch.Tensor) -> Image.Image:
    x = x.clamp(-1, 1)
    x = rearrange(x, "c h w -> h w c")
    return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())


@record
def inference(
Review comment: not obvious to me why this is worth a standalone function -- can we just call
Reply: The only difference here is handling the batching. This could be handled by a method in the trainer, but I previously had it that way and refactored it out after discussions in #1205. I think both are valid, and it's a matter of a design decision for torchtitan.
Review comment: I see. Sounds OK to me.
    prompts: list[str],
    trainer: FluxTrainer,
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    bs: int = 1,
):
    """
    Run inference on the Flux model.
    """
    results = []
    with torch.no_grad():
        for i in range(0, len(prompts), bs):
            images = generate_image(
                device=trainer.device,
                dtype=trainer._dtype,
                job_config=trainer.job_config,
                model=trainer.model_parts[0],
                prompt=prompts[i : i + bs],
                autoencoder=trainer.autoencoder,
                t5_tokenizer=t5_tokenizer,
                clip_tokenizer=clip_tokenizer,
                t5_encoder=trainer.t5_encoder,
                clip_encoder=trainer.clip_encoder,
            )
            results.append(images.detach())
    results = torch.cat(results, dim=0)
    return results
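The loop above consumes the prompt list in steps of `bs`, so it runs ceil(len(prompts) / bs) times and the final batch may be smaller; a quick standalone illustration:

```python
prompts = ["a", "b", "c", "d", "e"]
bs = 2
batches = [prompts[i : i + bs] for i in range(0, len(prompts), bs)]
assert batches == [["a", "b"], ["c", "d"], ["e"]]  # last batch may be short
```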


if __name__ == "__main__":
Review comment: It's better to put the following logic in a separate function (like
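The suggested function name is cut off above; a minimal sketch of the suggested structure (the name `main` is an assumption, not from the PR):

```python
def main() -> None:
    # parse config, build the trainer, shard prompts, run inference, save images
    ...


if __name__ == "__main__":
    main()
```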
    init_logger()
    config_manager = ConfigManager()
    config = config_manager.parse_args()
    trainer = FluxTrainer(config)
    world_size = int(os.environ["WORLD_SIZE"])
    global_id = int(os.environ["RANK"])
Review comment: Can we use
    original_prompts = open(config.inference.prompts_path).readlines()
    total_prompts = len(original_prompts)

    # Each process processes its shard
    prompts = original_prompts[global_id::world_size]

    trainer.checkpointer.load(step=config.checkpoint.load_step)
    t5_tokenizer = FluxTokenizer(
        config.encoder.t5_encoder,
        max_length=config.encoder.max_t5_encoding_len,
    )
    clip_tokenizer = FluxTokenizer(config.encoder.clip_encoder, max_length=77)

    if global_id == 0:
Review comment: This is not required; the logging output is controlled by the torchrun configuration, which torchtitan run scripts default to rank 0 only.
        logger.info("Starting inference...")

    if prompts:
        images = inference(
            prompts,
            trainer,
            t5_tokenizer,
            clip_tokenizer,
            bs=config.inference.batch_size,
        )
        # pad the outputs to make sure all ranks have the same number of images for the gather step
        images = torch.cat(
            [
                images,
                torch.zeros(
                    math.ceil(total_prompts / world_size) - images.shape[0],
                    3,
                    256,
                    256,
+98
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make these magic numbers constant variables for the readability purpose? |
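A sketch of that suggestion (the constant names are assumptions, not from the PR):

```python
# Hypothetical named constants for the generated image shape
IMG_CHANNELS = 3
IMG_HEIGHT = IMG_WIDTH = 256
```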
                    device=trainer.device,
                ),
            ]
        )
    else:
        # if there are not enough prompts for all ranks, pad with empty tensors
        images = torch.zeros(
            math.ceil(total_prompts / world_size), 3, 256, 256, device=trainer.device
        )

    # Create a list of tensors to gather results
    gathered_images = [
        torch.zeros_like(images, device=trainer.device) for _ in range(world_size)
    ]

    # Gather images from all processes
    torch.distributed.all_gather(gathered_images, images)
Review comment: This should be

    # re-order the images to match the original ordering of prompts
    if global_id == 0:
Review comment: Another good motivation for making the logic in
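The reassembly loop below inverts the round-robin sharding `prompts[global_id::world_size]`; a standalone check of that index arithmetic (pure Python, no distributed setup required):

```python
import math

world_size = 4
prompts = [f"p{i}" for i in range(10)]
shards = [prompts[r::world_size] for r in range(world_size)]  # per-rank shards
per_rank = math.ceil(len(prompts) / world_size)

restored = [None] * len(prompts)
for j in range(per_rank):            # in_rank_index
    for r in range(world_size):      # rank_index
        g = r + j * world_size       # global prompt index
        if g >= len(prompts):
            break
        restored[g] = shards[r][j]
assert restored == prompts
```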
        all_images = torch.zeros(
            size=[total_prompts, 3, 256, 256],
            dtype=torch.float32,
            device=trainer.device,
        )
        for in_rank_index in range(math.ceil(total_prompts / world_size)):
            for rank_index in range(world_size):
                global_idx = rank_index + in_rank_index * world_size
                if global_idx >= total_prompts:
                    break
                all_images[global_idx] = gathered_images[rank_index][in_rank_index]
        logger.info("Inference done")

        pil_images = [torch_to_pil(img) for img in all_images]
        if config.inference.save_path:
            path = Path(config.job.dump_folder, config.inference.save_path)
Review comment: Can we reuse the
            path.mkdir(parents=True, exist_ok=True)
            for i, img in enumerate(pil_images):
                exif_data = Image.Exif()
                exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
                exif_data[ExifTags.Base.Make] = "Black Forest Labs"
                exif_data[ExifTags.Base.Model] = "Schnell"
                exif_data[ExifTags.Base.ImageDescription] = original_prompts[i]
                img.save(
                    path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0
Review comment: If eventually we are saving individual image files, why do we even perform gather / all-gather? We could save different images in the same folder from different ranks, just with unique names, i.e.
Reply: That's true. The tricky part is placing all the tensors back in the same order so we can match them up with the prompts. Because of the padding involved it's not super straightforward, but I'm sure there's a way to do it while having each rank writes its own images. It just might take some more thought.
Review comment: An easier way to bypass padding is to require that the prompts file have a length divisible by the DP degree, or the world size. Users of this script can manually add empty rows if needed.
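The filename pattern in the comment above is cut off; one way each rank could write its own shard while keeping a global ordering (a sketch with assumed names, not part of this PR):

```python
from pathlib import Path

rank, world_size = 0, 8      # from torch.distributed.get_rank() / get_world_size()
pil_images: list = []        # this rank's images, in shard order
out_dir = Path("inference_results")
for j, img in enumerate(pil_images):
    g = rank + j * world_size              # undo prompts[rank::world_size]
    img.save(out_dir / f"img_{g}.png")     # unique, globally ordered filename
```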
                )
    torch.distributed.destroy_process_group()
Review comment: We should do something like:
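The reviewer's snippet is truncated above; a common pattern (an assumption about what was meant, not the reviewer's actual code) is to guard the teardown so it is a no-op when the process group was never initialized:

```python
import torch.distributed as dist

if dist.is_initialized():
    dist.destroy_process_group()
```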
Review comment: To group files a bit more logically, can we put run_inference.sh, prompts.txt, and infer.py under the flux/inference folder? We can leave sampling.py outside, as it's also used by the evaluation in train.py.