
Commit e640ee7
add HuggingFace streaming support in data input pipeline
1 parent 33bb598

16 files changed, +791 −299 lines

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 jax>=0.4.30
 jaxlib>=0.4.30
+grain-nightly
 google-cloud-storage==2.17.0
 absl-py
 datasets
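
The pipeline added below builds on Grain's Python API, which the new grain-nightly requirement provides. As a quick, hedged sanity check of the added dependency (the symbols are the ones the new pipeline file uses; nothing here is maxdiffusion-specific):

# Confirm the new dependency resolves alongside the existing `datasets` requirement.
import grain.python as grain       # provided by grain-nightly
from datasets import load_dataset  # provided by datasets

print(grain.DataLoader, grain.IndexSampler, grain.ShardOptions, load_dataset)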

src/maxdiffusion/configs/base14.yml

Lines changed: 12 additions & 7 deletions
@@ -125,11 +125,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It precomputes image latents and text encoder outputs
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_sd15'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
@@ -145,11 +155,6 @@ enable_data_shuffling: True
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
 
-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
-
 # Training loop
 learning_rate: 1.e-7
 scale_lr: False
@@ -205,4 +210,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
+cache_dreambooth_dataset: False
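
The new keys split dataset handling into two modes: the existing in-memory path with optional latent/text-encoder caching, and the streaming path this commit adds. A minimal sketch of how the keys relate, assuming "hf" is the dataset_type value that selects streaming (the accepted values are not listed in this diff) and using a hypothetical helper name:

def describe_dataset_mode(config):
  # Illustrative only; the attribute names mirror the config keys added above.
  if config.dataset_type == "tf":
    # Small dataset that fits in memory: image latents and text-encoder outputs can be
    # precomputed (cache_latents_text_encoder_outputs) and saved to
    # config.dataset_save_location to cut memory use and step time.
    return "tf: cached in memory"
  if config.dataset_type == "hf":
    # Streaming from the HF Hub or GCS: hf_data_dir, hf_train_files, and hf_access_token
    # are forwarded to datasets.load_dataset(..., streaming=True); nothing is cached up front.
    return "hf: streamed"
  raise ValueError(f"unsupported dataset_type: {config.dataset_type}")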

src/maxdiffusion/configs/base21.yml

Lines changed: 12 additions & 2 deletions
@@ -127,11 +127,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It precomputes image latents and text encoder outputs
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_sd21'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 768
@@ -201,4 +211,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
+cache_dreambooth_dataset: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 12 additions & 7 deletions
@@ -140,11 +140,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It precomputes image latents and text encoder outputs
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
@@ -160,11 +170,6 @@ enable_data_shuffling: True
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
 
-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
-
 # Training loop
 learning_rate: 1.e-7
 scale_lr: False
@@ -218,4 +223,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
+cache_dreambooth_dataset: False

src/maxdiffusion/configs/base_xl.yml

Lines changed: 12 additions & 6 deletions
@@ -128,11 +128,21 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+train_split: 'train'
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It precomputes image latents and text encoder outputs
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 1024
@@ -148,10 +158,6 @@ enable_data_shuffling: True
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
 
-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
 # Training loop
 learning_rate: 4.e-7
 scale_lr: False
@@ -204,4 +210,4 @@ enable_mllog: False
 controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
-controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
+controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings
import datasets
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
import grain.python as grain

from maxdiffusion import max_logging
from maxdiffusion import multihost_dataloading


def make_hf_streaming_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
    tokenize_fn=None,
    image_transforms_fn=None,
):
  """Streams data from the HF Hub or a GCS bucket.
  Nothing is downloaded or cached up front, regardless of config.cache_latents_text_encoder_outputs."""
  ds = load_dataset(
      config.dataset_name,
      split=config.train_split,
      data_dir=config.hf_data_dir,
      data_files=config.hf_train_files,
      streaming=True,
      token=config.hf_access_token,
  )

  ds = ds.shuffle(seed=config.seed)
  ds = ds.select_columns([config.caption_column, config.image_column])

  if tokenize_fn:
    ds = ds.map(
        function=tokenize_fn,
        batched=True,
        remove_columns=[config.caption_column],
    )

  if image_transforms_fn:
    ds = ds.map(
        function=image_transforms_fn,
        batched=True,
        remove_columns=[config.image_column],
    )

  ds = HFDataSource(
      ds,
      dataloading_host_index,
      dataloading_host_count,
  )
  dummy_index_sampler = grain.IndexSampler(
      num_records=len(ds),
      num_epochs=1,
      shard_options=grain.ShardOptions(
          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
      ),
      shuffle=False,
      seed=0,
  )
  operations = [grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True)]
  dataloader = grain.DataLoader(
      data_source=ds,
      operations=operations,
      sampler=dummy_index_sampler,
      worker_count=1,  # only one worker is supported for now; more workers result in duplicated data
      worker_buffer_size=1,
      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=64),
  )
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
  return train_iter


class HFDataSource(grain.RandomAccessDataSource):
  """Wraps a HuggingFace IterableDataset as a grain data source, without real random-access support."""

  def __init__(
      self,
      dataset: datasets.IterableDataset,
      dataloading_host_index: int,
      dataloading_host_count: int,
  ):
    self.dataset = dataset
    self.dataloading_host_count = dataloading_host_count
    self.dataloading_host_index = dataloading_host_index
    self.n_shards = dataset.n_shards
    self._check_shard_count()
    self.current_shard = dataloading_host_index
    self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = None

  def _check_shard_count(self):
    if self.n_shards < self.dataloading_host_count:
      warnings.warn(
          f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, "
          "fewer than the number of hosts loading data. This is known to lead to inefficient dataloading. "
          "See https://github.com/AI-Hypercomputer/maxdiffusion/blob/main/docs/data_README.md#best-practice"
      )
      self.n_shards = self.dataloading_host_count

  def _update_shard(self):
    # Round-robin over shards: host i reads shards i, i + host_count, i + 2 * host_count, ... modulo n_shards.
    new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards
    max_logging.log(f"Updating host {self.dataloading_host_index} dataset from shard {self.current_shard} to {new_shard}")
    self.current_shard = new_shard
    self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = iter(self.dataset_shard)

  def __len__(self):
    """Return the length of the HF dataset. Since a HuggingFace IterableDataset has no length,
    a fake length much larger than the dataset is returned."""
    return 10_000_000_000

  def __getitem__(self, index):
    """Since a HuggingFace IterableDataset does not support random access by index,
    the next item from the iterator is returned instead."""
    if not self.data_iter:
      self.data_iter = iter(self.dataset_shard)

    while True:
      try:
        data = next(self.data_iter)
        return data
      except StopIteration:
        self._update_shard()
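
A hedged usage sketch for the new iterator. Only make_hf_streaming_iterator, its parameter names, and the config field names come from this commit; the SimpleNamespace config, the mesh shape, the batch size, and both batched callbacks are illustrative placeholders.

from types import SimpleNamespace

import jax
import numpy as np
from jax.sharding import Mesh

# Config stand-in; field names mirror the yml keys above. The yml defaults are empty
# strings, but None is used here so the sketch stays self-contained for load_dataset.
config = SimpleNamespace(
    dataset_name="diffusers/pokemon-gpt4-captions",
    train_split="train",
    hf_data_dir=None,
    hf_train_files=None,
    hf_access_token=None,
    caption_column="text",
    image_column="image",
    seed=0,
)

def tokenize_fn(batch):
  # Hypothetical batched tokenizer; a real one would run the CLIP tokenizer on batch["text"].
  batch["input_ids"] = [np.zeros((77,), dtype=np.int32) for _ in batch["text"]]
  return batch

def image_transforms_fn(batch):
  # Hypothetical batched image transform: resize + float conversion of PIL images.
  batch["pixel_values"] = [
      np.asarray(img.convert("RGB").resize((512, 512)), dtype=np.float32) for img in batch["image"]
  ]
  return batch

devices = np.asarray(jax.devices()).reshape(jax.process_count(), jax.local_device_count())
mesh = Mesh(devices, axis_names=("data", "model"))  # illustrative two-axis mesh

train_iter = make_hf_streaming_iterator(
    config,
    dataloading_host_index=jax.process_index(),
    dataloading_host_count=jax.process_count(),
    mesh=mesh,
    global_batch_size=8,
    tokenize_fn=tokenize_fn,
    image_transforms_fn=image_transforms_fn,
)
batch = next(train_iter)  # one globally sharded batch per training step

On a multi-host setup each process builds its own iterator with its jax.process_index(); HFDataSource then walks the dataset shards round-robin, so host i reads shards i, i + host_count, i + 2 * host_count, and so on.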
