
Use torch cpu, async write to tensorboard, script to convert latents … #182

Open · wants to merge 1 commit into base: main
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,13 +1,14 @@
jax>=0.4.30
--extra-index-url https://download.pytorch.org/whl/cpu
jax==0.5.3
jaxlib>=0.4.30
grain-nightly==0.0.10
google-cloud-storage==2.17.0
absl-py
datasets
flax>=0.10.2
optax>=0.2.3
torch==2.5.1
torchvision==0.20.1
torch==2.6.0
torchvision>=0.20.1
ftfy
tensorboard>=2.17.0
tensorboardx==2.6.2.2
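The --extra-index-url line points pip at PyTorch's CPU-only wheel index, so torch==2.6.0 and torchvision should resolve to builds without CUDA dependencies. A quick post-install sanity check (a sketch, assuming the file above was installed with pip install -r requirements.txt):

import torch

# The CPU wheel typically reports a "+cpu" local version tag and no usable CUDA runtime.
print(torch.__version__)          # e.g. "2.6.0+cpu"
print(torch.cuda.is_available())  # expected: False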
63 changes: 59 additions & 4 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -21,7 +21,7 @@

from maxdiffusion import multihost_dataloading

AUTOTUNE = tf.data.experimental.AUTOTUNE
AUTOTUNE = tf.data.AUTOTUNE


def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
@@ -31,7 +31,7 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
  if shuffle:
    tf_dataset = tf_dataset.shuffle(len(tf_dataset))
  tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
  tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
  tf_dataset = tf_dataset.prefetch(AUTOTUNE)
  tf_dataset = tf_dataset.repeat(-1)

  return tf_dataset
@@ -74,6 +74,57 @@ def make_tf_iterator(
  return train_iter


def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """
  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
  latents, input_ids, prompt_embeds, and text_embeds.
  """
  feature_description = {
      "pixel_values": tf.io.FixedLenFeature([], tf.string),
      "input_ids": tf.io.FixedLenFeature([], tf.string),
      "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
      "text_embeds": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
    input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
    prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
    text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)

    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "prompt_embeds": prompt_embeds,
        "text_embeds": text_embeds
    }

  # This pipeline reads the sharded files and applies the parsing and preparation.
  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
  )

  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
  return train_iter
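A minimal usage sketch for the new iterator; the config fields, mesh, and batch size below are placeholders rather than values taken from this PR, and only illustrate the expected call shape and batch layout:

# Hypothetical wiring; host index/count come from the JAX runtime.
train_iter = make_cached_tfrecord_iterator(
    config,
    dataloading_host_index=jax.process_index(),
    dataloading_host_count=jax.process_count(),
    mesh=mesh,
    global_batch_size=global_batch_size,
)
batch = next(train_iter)
# Each batch is expected to be a dict holding the four pre-computed features:
# "pixel_values" (latents), "input_ids", "prompt_embeds", "text_embeds".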

# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
def make_tfrecord_iterator(
    config,
@@ -86,19 +86,23 @@ def make_tfrecord_iterator(
  check out preparation script
  maxdiffusion/pedagogical_examples/to_tfrecords.py
  """
  if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
    return make_cached_tfrecord_iterator(config, dataloading_host_index,
                                         dataloading_host_count, mesh,
                                         global_batch_size)
  feature_description = {
      "moments": tf.io.FixedLenFeature([], tf.string),
      "clip_embeddings": tf.io.FixedLenFeature([], tf.string),
  }

  def _parse_tfrecord_fn(example):
    return tf.io.parse_single_example(example, feature_description)

  def prepare_sample(features):
    moments = tf.io.parse_tensor(tnp.asarray(features["moments"]), out_type=tf.float32)
    clip_embeddings = tf.io.parse_tensor(tnp.asarray(features["clip_embeddings"]), out_type=tf.float32)
    return {"pixel_values": moments, "input_ids": clip_embeddings}

  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
@@ -0,0 +1,96 @@
import os
import argparse
import tensorflow as tf
from datasets import load_from_disk
import numpy as np

def _bytes_feature(value):
  """Returns a bytes_list from a serialized tensor."""
  if not isinstance(value, tf.Tensor):
    value = tf.constant(value)
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))

def create_4_feature_example(record):
  """Creates a tf.train.Example proto with all 4 pre-computed features."""
  pixel_values = tf.io.serialize_tensor(record['pixel_values'])
  input_ids = tf.io.serialize_tensor(record['input_ids'])
  prompt_embeds = tf.io.serialize_tensor(record['prompt_embeds'])
  text_embeds = tf.io.serialize_tensor(record['text_embeds'])

  feature = {
      "pixel_values": _bytes_feature(pixel_values),
      "input_ids": _bytes_feature(input_ids),
      "prompt_embeds": _bytes_feature(prompt_embeds),
      "text_embeds": _bytes_feature(text_embeds)
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

def run(args):
  """Main processing function."""
  # Load the cached dataset from the location specified in the arguments
  print(f"Loading processed dataset from disk: {args.dataset_save_location}")
  processed_ds = load_from_disk(args.dataset_save_location)
  print("Dataset loaded successfully.")

  # Get sharding and output directory from the arguments
  tfrecords_dir = args.tfrecords_dir
  num_shards = args.data_num_shards
  os.makedirs(tfrecords_dir, exist_ok=True)

  writers = [
      tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord"))
      for i in range(num_shards)
  ]

  print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...")

  for i, record in enumerate(processed_ds):
    # Create a new record with explicit casting for float types
    casted_record = {
        "pixel_values": np.float32(record['pixel_values']),
        "input_ids": record['input_ids'],  # This is already integer type
        "prompt_embeds": np.float32(record['prompt_embeds']),
        "text_embeds": np.float32(record['text_embeds'])
    }

    writer_index = i % num_shards
    tf_example = create_4_feature_example(casted_record)
    writers[writer_index].write(tf_example.SerializeToString())

  for writer in writers:
    writer.close()

  print("TFRecord conversion complete.")


def main():
  """Parses command-line arguments and runs the conversion."""
  parser = argparse.ArgumentParser(
      description="Convert a cached Hugging Face dataset to sharded TFRecords."
  )
  parser.add_argument(
      "--dataset_save_location",
      type=str,
      required=False,
      default="/tmp/pokemon-gpt4-captions_xl",
      help="Path to the cached dataset created by the training pipeline."
  )
  parser.add_argument(
      "--tfrecords_dir",
      type=str,
      required=False,
      default="/tmp/cached_pokemon_tfrecords_sharded",
      help="Output directory to save the sharded TFRecord files."
  )
  parser.add_argument(
      "--data_num_shards",
      type=int,
      default=128,
      help="Number of shards to split the TFRecord dataset into."
  )

  args = parser.parse_args()
  run(args)

if __name__ == "__main__":
  main()
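A quick way to check a written shard (a sketch; the path assumes the default --tfrecords_dir and --data_num_shards values above):

import tensorflow as tf

shard = "/tmp/cached_pokemon_tfrecords_sharded/shard-00000-of-00128.tfrecord"
for raw in tf.data.TFRecordDataset([shard]).take(1):
  example = tf.train.Example.FromString(raw.numpy())
  serialized = example.features.feature["prompt_embeds"].bytes_list.value[0]
  # Should deserialize back to the float32 embedding tensor written by the script.
  print(tf.io.parse_tensor(serialized, out_type=tf.float32).shape)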
45 changes: 34 additions & 11 deletions src/maxdiffusion/train_utils.py
@@ -17,6 +17,8 @@
import numpy as np
import jax
import jax.numpy as jnp
import threading
import queue

from maxdiffusion import max_utils, max_logging

@@ -67,10 +69,28 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
  metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
  metrics["scalar"].update({"learning/current_learning_rate": lr})


_metrics_queue = queue.Queue()
_buffered_step = None
_buffered_metrics = None

def _tensorboard_writer_worker(writer, config):
  """
  A worker function that runs in a separate thread.
  It waits for metrics to appear in the queue and writes them to TensorBoard.
  """
  while True:
    data = _metrics_queue.get()
    if data is None:
      break
    metrics, step = data
    if jax.process_index() == 0:
      for metric_name in metrics.get("scalar", []):
        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
      for metric_name in metrics.get("scalars", []):
        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

    if step % config.log_period == 0:
      writer.flush()
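The diff does not show where the worker thread is started or stopped; a minimal sketch of the intended pattern (the thread start and the None shutdown sentinel below are assumptions, not code from this PR):

# Hypothetical setup before the training loop:
writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, config), daemon=True)
writer_thread.start()
# The training loop then calls write_metrics(...), which only enqueues (metrics, step).
# Hypothetical shutdown after the last step:
_metrics_queue.put(None)  # sentinel matched by the worker's "if data is None: break"
writer_thread.join()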

def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
  """Entry point for all metrics writing in Train's Main.
@@ -81,15 +101,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
  The logic is that this ensures that Jax is able to queue train_steps and we
  don't block when turning "lazy" Jax arrays into real Python numbers.
  """
  global _buffered_step, _buffered_metrics
  global _buffered_step, _buffered_metrics, _metrics_queue

  if metrics:
    _metrics_queue.put((metrics, step))
  if _buffered_metrics is not None:
    if config.metrics_file:
      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

    if _buffered_step is None:
      raise ValueError(f"When writing metrics, {_buffered_step=} was none")
    write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)

    if config.metrics_file:
      max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)

    if config.gcs_metrics and jax.process_index() == 0:
      running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
@@ -100,13 +123,6 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):

def write_metrics_to_tensorboard(writer, metrics, step, config):
  """Writes metrics to tensorboard"""
  if jax.process_index() == 0:
    for metric_name in metrics.get("scalar", []):
      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
    for metric_name in metrics.get("scalars", []):
      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

  full_log = step % config.log_period == 0
  if jax.process_index() == 0:
    max_logging.log(
        "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
@@ -116,6 +132,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
            float(metrics["scalar"]["learning/loss"]),
        )
    )
  if jax.process_index() == 0:
    for metric_name in metrics.get("scalar", []):
      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
    for metric_name in metrics.get("scalars", []):
      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)

  full_log = step % config.log_period == 0

  if full_log and jax.process_index() == 0:
    max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
2 changes: 1 addition & 1 deletion src/maxdiffusion/trainers/flux_trainer.py
@@ -87,7 +87,7 @@ def start_training(self):
    state_shardings = {}

    # move params to accelerator
    encoders_sharding = PositionalSharding(self.devices_array).replicate()
    encoders_sharding = jax.NamedSharding(self.mesh, P(None))
    partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
    pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
    pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
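PositionalSharding has been deprecated in newer JAX releases, which is presumably why this line switches to NamedSharding; a PartitionSpec that names no mesh axes replicates the value across every device. A standalone sketch of the same replication pattern (the mesh construction and axis name are placeholders, not taken from this repo):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))
replicated = NamedSharding(mesh, P(None))  # no mesh axis referenced -> fully replicated
x = jax.device_put(jnp.ones((4, 8)), replicated)
print(x.sharding)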