Skip to content

Commit e3b9695

Browse files
committed
Add SDXL implementation for PyTorch/XLA training
1 parent 6edb774 commit e3b9695

File tree

5 files changed

+1064
-11
lines changed

5 files changed

+1064
-11
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# SDXL fine-tuning using PyTorch/XLA
2+
3+
The `train_sdxl.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
4+
5+
It has been tested on v5p TPU versions.
6+
7+
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
8+
where we shard the input batches over the TPU devices.
9+
10+
As of 04-03-2025, these are some expected step times.
11+
12+
| accelerator | global batch size | step time (seconds) |
13+
| ----------- | ----------------- | --------- |
14+
| v5p-8 | 32 | 0.92 |
15+
| v5p-8 | 64 | 1.66 |
16+
17+
## Create TPU
18+
19+
To create a TPU on Google Cloud first set these environment variables:
20+
21+
```bash
22+
export TPU_NAME=<tpu-name>
23+
export PROJECT_ID=<project-id>
24+
export ZONE=<google-cloud-zone>
25+
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
26+
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
27+
```
28+
29+
Then run the create TPU command:
30+
```bash
31+
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID}
32+
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION}
33+
--reserved
34+
```
35+
36+
You can also use other ways to reserve TPUs like GKE or queued resources.
37+
38+
## Setup TPU environment
39+
40+
Install PyTorch and PyTorch/XLA nightly versions:
41+
```bash
42+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
43+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
44+
--command='
45+
pip3 install --pre torch==2.8.0.dev20250403+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
46+
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250403.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
47+
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
48+
'
49+
```
50+
51+
Verify that PyTorch and PyTorch/XLA were installed correctly:
52+
53+
```bash
54+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
55+
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
56+
--command='python3 -c "import torch; import torch_xla;"'
57+
```
58+
59+
Install dependencies:
60+
```bash
61+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
62+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
63+
--command='
64+
git clone -b sdxl_xla https://github.com/pytorch-tpu/diffusers.git
65+
cd diffusers
66+
cd examples/research_projects/pytorch_xla/text_to_image/
67+
pip3 install -r requirements.txt
68+
pip3 install pillow --upgrade
69+
cd ../../..
70+
pip3 install .'
71+
```
72+
73+
## Run the training job
74+
75+
### Authenticate
76+
77+
Run the following command to authenticate your token.
78+
79+
```bash
80+
huggingface-cli login
81+
```
82+
83+
This script only trains the unet part of the network. The VAE and text encoder
84+
are fixed.
85+
86+
```bash
87+
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
88+
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
89+
--command='
90+
export XLA_DISABLE_FUNCTIONALIZATION=1
91+
export PROFILE_DIR=/tmp/
92+
export CACHE_DIR=/tmp/
93+
export DATASET_NAME=lambdalabs/naruto-blip-captions
94+
export PER_HOST_BATCH_SIZE=64 # This is known to work on TPU v5p
95+
export TRAIN_STEPS=50
96+
export PROFILE_START_STEP=10
97+
export OUTPUT_DIR=/tmp/trained-model/
98+
python diffusers/examples/research_projects/pytorch_xla/text_to_image/train_sdxl.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 --dataset_name=$DATASET_NAME --resolution=1024 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --measure_start_step=$PROFILE_START_STEP --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=5000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4 --xla_gradient_checkpointing'
99+
```
100+
101+
Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
102+
103+
### Environment Envs Explained
104+
105+
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
106+
* `PROFILE_DIR`: Specify where to put the profiling results.
107+
* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching.
108+
* `DATASET_NAME`: Dataset to train the model.
109+
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis.
110+
* `TRAIN_STEPS`: Total number of training steps to run the training for.
111+
* `OUTPUT_DIR`: Directory to store the fine-tuned model.
112+
113+
## Run inference using the output model
114+
115+
To run inference using the output, you can simply load the model and pass it
116+
input prompts. The first pass will compile the graph and takes longer with the following passes running much faster.
117+
118+
```bash
119+
export CACHE_DIR=/tmp/
120+
```
121+
122+
```python
123+
import torch
124+
import os
125+
import sys
126+
import numpy as np
127+
128+
import torch_xla.core.xla_model as xm
129+
from time import time
130+
from diffusers import StableDiffusionPipeline
131+
import torch_xla.runtime as xr
132+
133+
CACHE_DIR = os.environ.get("CACHE_DIR", None)
134+
if CACHE_DIR:
135+
xr.initialize_cache(CACHE_DIR, readonly=False)
136+
137+
def main():
138+
device = xm.xla_device()
139+
model_path = "jffacevedo/pxla_trained_model"
140+
pipe = StableDiffusionPipeline.from_pretrained(
141+
model_path,
142+
torch_dtype=torch.bfloat16
143+
)
144+
pipe.to(device)
145+
prompt = ["A naruto with green eyes and red legs."]
146+
start = time()
147+
print("compiling...")
148+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
149+
print(f"compile time: {time() - start}")
150+
print("generate...")
151+
start = time()
152+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
153+
print(f"generation time (after compile) : {time() - start}")
154+
image.save("naruto.png")
155+
156+
if __name__ == '__main__':
157+
main()
158+
```
159+
160+
Expected Results:
161+
162+
```bash
163+
compiling...
164+
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
165+
compile time: 720.656970500946
166+
generate...
167+
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
168+
generation time (after compile) : 1.8461642265319824

0 commit comments

Comments
 (0)