
Commit 9b20fb3

add HuggingFace streaming support in data input pipeline
1 parent 33bb598 commit 9b20fb3

16 files changed: +784 −289 lines

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 jax>=0.4.30
 jaxlib>=0.4.30
+grain-nightly
 google-cloud-storage==2.17.0
 absl-py
 datasets

src/maxdiffusion/configs/base14.yml

Lines changed: 11 additions & 7 deletions
@@ -125,11 +125,20 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_sd15'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
@@ -145,11 +154,6 @@ enable_data_shuffling: True
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1

-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
-
 # Training loop
 learning_rate: 1.e-7
 scale_lr: False
@@ -205,4 +209,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
+cache_dreambooth_dataset: False
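
The new hf_data_dir, hf_train_files, and hf_access_token keys feed the streaming loader added later in this commit. As a rough illustration only (the split, placeholder values, and the printed column names below are assumptions, not part of the commit), they map onto a datasets.load_dataset call like this:

# Minimal sketch of how the new config keys are consumed by the streaming
# pipeline; placeholder values stand in for the config fields.
from datasets import load_dataset

ds = load_dataset(
    "diffusers/pokemon-gpt4-captions",  # config.dataset_name
    split="train",
    data_dir=None,      # config.hf_data_dir ('' means not set)
    data_files=None,    # config.hf_train_files ('' means not set)
    streaming=True,     # stream samples instead of downloading the dataset
    token=None,         # config.hf_access_token for gated/private datasets
)
print(next(iter(ds)).keys())  # e.g. dict_keys(['image', 'text'])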

src/maxdiffusion/configs/base21.yml

Lines changed: 11 additions & 2 deletions
@@ -127,11 +127,20 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_sd21'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 768
@@ -201,4 +210,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
+cache_dreambooth_dataset: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 11 additions & 7 deletions
@@ -140,11 +140,20 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 512
@@ -160,11 +169,6 @@ enable_data_shuffling: True
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1

-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
-
 # Training loop
 learning_rate: 1.e-7
 scale_lr: False
@@ -218,4 +222,4 @@ class_prompt: ''
 prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
-cache_dreambooth_dataset: False
+cache_dreambooth_dataset: False

src/maxdiffusion/configs/base_xl.yml

Lines changed: 11 additions & 6 deletions
@@ -128,11 +128,20 @@ ici_tensor_parallelism: 1
 # Dataset
 # Replace with dataset path or train_data_dir. One has to be set.
 dataset_name: 'diffusers/pokemon-gpt4-captions'
-# saves transformed dataset of dataset_name.
+dataset_type: 'tf'
+cache_latents_text_encoder_outputs: True
+# cache_latents_text_encoder_outputs only applies to dataset_type="tf"
+# and only to small datasets that fit in memory.
+# It prepares image latents and text encoder outputs ahead of time
+# to reduce memory consumption and step time during training.
+# The transformed dataset is saved at dataset_save_location.
 dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
 train_data_dir: ''
 dataset_config_name: ''
 jax_cache_dir: ''
+hf_data_dir: ''
+hf_train_files: ''
+hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
 resolution: 1024
@@ -148,10 +157,6 @@ enable_data_shuffling: True
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1

-# Prepare image latents and text encoder outputs
-# during dataset creation to reduce memory consumption.
-cache_latents_text_encoder_outputs: True
-
 # Training loop
 learning_rate: 4.e-7
 scale_lr: False
@@ -204,4 +209,4 @@ enable_mllog: False
 controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
-controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
+controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
Lines changed: 142 additions & 0 deletions
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import warnings
import datasets
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
import grain.python as grain

from maxdiffusion import max_logging
from maxdiffusion import multihost_dataloading


def make_hf_streaming_iterator(
    config,
    split,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
    tokenize_fn=None,
    image_transforms_fn=None,
):
  """Streams data from the HF Hub or a GCS bucket.
  Nothing is downloaded, regardless of config.cache_latents_text_encoder_outputs."""
  ds = load_dataset(
      config.dataset_name,
      split=split,
      data_dir=config.hf_data_dir,
      data_files=config.hf_train_files,
      streaming=True,
      token=config.hf_access_token,
  )

  ds = ds.shuffle(seed=config.seed)

  if tokenize_fn:
    ds = ds.map(
        function=tokenize_fn,
        batched=True,
        remove_columns=[config.caption_column],
    )

  if image_transforms_fn:
    ds = ds.map(
        function=image_transforms_fn,
        batched=True,
        remove_columns=[config.image_column],
    )

  ds = HFDataSource(
      ds,
      dataloading_host_index,
      dataloading_host_count,
  )
  dummy_index_sampler = grain.IndexSampler(
      num_records=len(ds),
      num_epochs=1,
      shard_options=grain.ShardOptions(
          shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False
      ),
      shuffle=False,
      seed=0,
  )
  operations = [grain.Batch(batch_size=global_batch_size // dataloading_host_count, drop_remainder=True)]
  dataloader = grain.DataLoader(
      data_source=ds,
      operations=operations,
      sampler=dummy_index_sampler,
      worker_count=1,  # only supports one worker for now; more workers result in duplicated data
      worker_buffer_size=1,
      read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=64),
  )
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(dataloader, mesh)
  return train_iter


class HFDataSource(grain.RandomAccessDataSource):
  """Wraps a HuggingFace IterableDataset as a grain data source, without random access support."""

  def __init__(
      self,
      dataset: datasets.IterableDataset,
      dataloading_host_index: int,
      dataloading_host_count: int,
  ):
    self.dataset = dataset
    self.dataloading_host_count = dataloading_host_count
    self.dataloading_host_index = dataloading_host_index
    self.n_shards = dataset.n_shards
    self._check_shard_count()
    self.current_shard = dataloading_host_index
    self.dataset_shard = split_dataset_by_node(dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = None

  def _check_shard_count(self):
    if self.n_shards < self.dataloading_host_count:
      warnings.warn(
          f"WARNING: Inefficient dataloading. Your train or eval dataset contains {self.n_shards} shards, "
          "fewer than the number of hosts loading data. This is known to lead to inefficient dataloading. "
          "See https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#multihost-dataloading-best-practice"
      )
      self.n_shards = self.dataloading_host_count

  def _update_shard(self):
    new_shard = (self.current_shard + self.dataloading_host_count) % self.n_shards
    max_logging.log(f"Updating host {self.dataloading_host_index} dataset, was on shard {self.current_shard}")
    max_logging.log(f"New shard is {new_shard}")
    self.current_shard = new_shard
    self.dataset_shard = split_dataset_by_node(self.dataset, world_size=self.n_shards, rank=self.current_shard)
    self.data_iter = iter(self.dataset_shard)

  def __len__(self):
    """Returns the length of the HF dataset. Since a HuggingFace IterableDataset has no length,
    a fake length larger than the dataset is returned."""
    return 10_000_000_000

  def __getitem__(self, index):
    """Since HuggingFace IterableDataset does not support random access by index,
    the next item in the iterator is returned instead."""
    if not self.data_iter:
      self.data_iter = iter(self.dataset_shard)

    while True:
      try:
        data = next(self.data_iter)
        return data
      except StopIteration:
        self._update_shard()
