diff --git a/examples/gemma/input_pipeline.py b/examples/gemma/input_pipeline.py index da9ae4733..82b3441e1 100644 --- a/examples/gemma/input_pipeline.py +++ b/examples/gemma/input_pipeline.py @@ -15,15 +15,11 @@ """Input pipeline for a LM1B dataset.""" import os -import typing +from typing import Any +import tokenizer import tensorflow as tf import tensorflow_datasets as tfds -import tokenizer -from clu import deterministic_data - -if typing.TYPE_CHECKING: - from train import TrainConfig AUTOTUNE = tf.data.experimental.AUTOTUNE Features = dict[str, tf.Tensor] @@ -58,9 +54,9 @@ def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset: def pack_dataset( - dataset: tf.data.Dataset, - key2length: int | dict[str, int], - keys: list[str] | None = None, + dataset: tf.data.Dataset, + key2length: int | dict[str, int], + keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. @@ -107,8 +103,8 @@ def pack_dataset( for k in keys: if k not in shapes: raise ValueError( - 'Key %s not found in dataset. Available keys are %s' - % (k, shapes.keys()) + 'Key %s not found in dataset. Available keys are %s' + % (k, shapes.keys()) ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') @@ -122,14 +118,14 @@ def pack_dataset( # trim to length dataset = dataset.map( - lambda x: {k: x[k][: key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE, + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. 
batch_size = max(key2length.values()) dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys} + batch_size, padded_shapes={k: [-1] for k in keys} ) dataset = _pack_with_tf_ops(dataset, keys, key2length) @@ -141,7 +137,7 @@ def my_fn(x): def _pack_with_tf_ops( - dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] + dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. @@ -166,8 +162,8 @@ def write_packed_example(partial, outputs): new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( - outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + outputs[k].size(), + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), ) return new_partial, new_outputs @@ -188,10 +184,10 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) def body_fn(i, partial, outputs): @@ -213,10 +209,10 @@ def body_fn(i, partial, outputs): one_example[k] = val for k in keys: can_append = tf.logical_and( - can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] - ), + can_append, + tf.less_equal( + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), ) def false_fn(): @@ -232,28 +228,28 @@ def true_fn(): new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], tf.range(new_seq_len)], 0 + [partial[k + '_position'], tf.range(new_seq_len)], 0 ) partial = new_partial return i + 1, partial, outputs # For loop 
over all examples in the batch. - i, partial, outputs = tf.while_loop( - cond=lambda *_: True, - body=body_fn, - loop_vars=(i, partial, outputs), - shape_invariants=( - tf.TensorShape([]), - {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] - {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] - ), - maximum_iterations=dynamic_batch_size, + _, partial, outputs = tf.while_loop( + cond=lambda *_: True, + body=body_fn, + loop_vars=(i, partial, outputs), + shape_invariants=( + tf.TensorShape([]), + {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] + {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] + ), + maximum_iterations=dynamic_batch_size, ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: packed[k + '_segmentation'] = tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed @@ -263,8 +259,8 @@ def true_fn(): def shift_data_by_truncation(x): # https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53 - x["inputs"] = x["inputs"][:-1] - x["targets"] = x["targets"][1:] + x['inputs'] = x['inputs'][:-1] + x['targets'] = x['targets'][1:] return x @@ -272,16 +268,16 @@ def shift_data_by_truncation(x): # Main dataset prep routines. 
# ----------------------------------------------------------------------------- def preprocess_data( - dataset, - shuffle: bool, - num_epochs: int | None = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - batch_size: int = 256, - drop_remainder: bool = True, - prefetch_size: int = AUTOTUNE, - shift: bool = True, + dataset, + shuffle: bool, + num_epochs: int | None = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + batch_size: int = 256, + drop_remainder: bool = True, + prefetch_size: int = AUTOTUNE, + shift: bool = True, ): """Shuffle and batch/pack the given dataset.""" @@ -303,7 +299,9 @@ def filter_fn(x): # Shift inputs for teacher-forced training if shift: dataset = dataset.map( - shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True + shift_data_by_truncation, + num_parallel_calls=AUTOTUNE, + deterministic=True, ) if pack_examples: @@ -311,10 +309,10 @@ def filter_fn(x): dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( - batch_size, - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=drop_remainder, + batch_size, + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=drop_remainder, ) if prefetch_size: @@ -324,10 +322,10 @@ def filter_fn(x): def get_datasets( - config: "TrainConfig", - *, - n_devices: int, - vocab_path: str | None = None, + config: Any, + *, + n_devices: int, + vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: @@ -343,16 +341,16 @@ def get_datasets( # Tokenize data. 
sp_processor = tokenizer.load_or_train_tokenizer( - train_data, - vocab_path=vocab_path, - vocab_size=config.vocab_size, - max_corpus_chars=config.max_corpus_chars, + train_data, + vocab_path=vocab_path, + vocab_size=config.vocab_size, + max_corpus_chars=config.max_corpus_chars, ) train_data = train_data.map( - tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE + tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE ) eval_data = eval_data.map( - tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE + tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE ) batch_size = config.per_device_batch_size * n_devices @@ -362,20 +360,20 @@ def get_datasets( eval_batch_size = batch_size train_ds = preprocess_data( - train_data, - shuffle=True, - num_epochs=None, - pack_examples=True, - batch_size=batch_size, - max_length=config.max_target_length, + train_data, + shuffle=True, + num_epochs=None, + pack_examples=True, + batch_size=batch_size, + max_length=config.max_target_length, ) eval_ds = preprocess_data( - eval_data, - shuffle=False, - pack_examples=False, - batch_size=eval_batch_size, - max_length=config.max_eval_target_length, + eval_data, + shuffle=False, + pack_examples=False, + batch_size=eval_batch_size, + max_length=config.max_eval_target_length, ) return train_ds, eval_ds, sp_processor diff --git a/examples/gemma/main.py b/examples/gemma/main.py index f4185e216..cd97f3f10 100644 --- a/examples/gemma/main.py +++ b/examples/gemma/main.py @@ -18,21 +18,24 @@ that can be easily tested and imported in Colab. 
""" -import jax -import tensorflow as tf -import train -from absl import app, flags, logging +from absl import app +from absl import flags +from absl import logging from clu import platform +import train +import jax from ml_collections import config_flags +import tensorflow as tf + FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( - 'config', - 'configs/default.py', - 'File path to the training hyperparameter configuration.', - lock_config=True, + 'config', + 'configs/default.py', + 'File path to the training hyperparameter configuration.', + lock_config=True, ) flags.mark_flags_as_required(['workdir']) @@ -51,11 +54,11 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( - f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}' + f'process_index: {jax.process_index()}, ' + f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( - platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/gemma/tokenizer.py b/examples/gemma/tokenizer.py index 5dcaa437e..2c351cf29 100644 --- a/examples/gemma/tokenizer.py +++ b/examples/gemma/tokenizer.py @@ -14,28 +14,31 @@ """Provides op for tokenizing a dataset.""" +from collections.abc import Iterable import dataclasses import os import sys import tempfile import time from typing import Any -from collections.abc import Iterable +from absl import logging import jax import tensorflow as tf + +from sentencepiece import SentencePieceProcessor # pylint: disable=g-importing-member +from sentencepiece import SentencePieceTrainer # pylint: disable=g-importing-member + if sys.version_info < (3, 13): import tensorflow_text as tftxt -from absl import 
logging -from sentencepiece import SentencePieceTrainer, SentencePieceProcessor Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('inputs', 'targets'), + dataset: tf.data.Dataset, + maxchars: int = int(1e7), + data_keys=('inputs', 'targets'), ) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -50,7 +53,7 @@ def _dump_chars_to_textfile( char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars' + delete=False, prefix='/tmp/ds_chars' ) as outfp: while char_count < maxchars: example = next(ds_iter) @@ -62,18 +65,18 @@ def _dump_chars_to_textfile( def _train_sentencepiece( - dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('inputs', 'targets'), - pad_id: int = 0, - eos_id: int = 1, - bos_id: int = 2, - unk_id: int = 3, + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), + pad_id: int = 0, + eos_id: int = 1, + bos_id: int = 2, + unk_id: int = 3, ): """Train SentencePiece tokenizer from subset of tf dataset. 
@@ -100,14 +103,13 @@ def _train_sentencepiece( else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys + dataset, maxchars=maxchars, data_keys=data_keys ) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp' + delete=False, prefix='/tmp/sp_tmp' ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join( - [ + argstr = ' '.join([ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', @@ -124,8 +126,7 @@ def _train_sentencepiece( f'--bos_id={bos_id}', f'--eos_id={eos_id}', f'--unk_id={unk_id}', - ] - ) + ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address @@ -142,27 +143,27 @@ def _train_sentencepiece( def _load_sentencepiece_tokenizer( - model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False, + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, ): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse ) return sp_tokenizer def load_or_train_tokenizer( - dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: tuple[str, str] = ('inputs', 'targets'), + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: @@ -170,11 +171,11 @@ def load_or_train_tokenizer( except tf.errors.NotFoundError: logging.info('SentencePiece vocab not found, 
building one from data.') vocab_path = _train_sentencepiece( - dataset, - vocab_size=vocab_size, - maxchars=max_corpus_chars, - model_path=vocab_path, - data_keys=data_keys, + dataset, + vocab_size=vocab_size, + maxchars=max_corpus_chars, + model_path=vocab_path, + data_keys=data_keys, ) return _load_sentencepiece_tokenizer(vocab_path) @@ -192,5 +193,5 @@ def __call__(self, features: Features) -> Features: def load_sentencepiece_processor(vocab_path: str): spp = SentencePieceProcessor() - spp.load(vocab_path) + spp.Load(vocab_path) return spp diff --git a/examples/gemma/train.py b/examples/gemma/train.py index b5bc07745..b4da6e951 100644 --- a/examples/gemma/train.py +++ b/examples/gemma/train.py @@ -22,26 +22,25 @@ import dataclasses import os +from typing import Any +from absl import logging +from clu import metric_writers +from clu import periodic_actions +from flax import nnx import input_pipeline -import jax -import jax.numpy as jnp +import sampler as sampler_lib import tokenizer import transformer as transformer_lib +import utils +from flax.training import checkpoints +from flax.training import common_utils +import jax +from jax import random +import jax.numpy as jnp import numpy as np import optax -import sampler as sampler_lib import tensorflow as tf -import utils -from absl import logging -from clu import metric_writers, periodic_actions -from jax import random -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P -from utils import TrainState - -from flax import nnx -from flax.training import checkpoints, common_utils @dataclasses.dataclass(unsafe_hash=True) @@ -53,13 +52,14 @@ class MeshRules: def __call__(self, *keys: str) -> tuple[str, ...]: return tuple( - getattr(self, key) if key is not None else None - for key in keys + getattr(self, key) if key is not None else None for key in keys ) @dataclasses.dataclass(unsafe_hash=True) class TrainConfig: + """Configuration for training a gemma model.""" + # Path to load or 
store sentencepiece vocab file. vocab_path: str | None # Vocabulary size if `vocab_path` is not given. @@ -107,10 +107,11 @@ class TrainConfig: # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: - # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, + # ...) transformer_name: str | None # or alternatively define the model using the dict of parameters - transformer_params: dict | None + transformer_params: dict[Any, Any] | None # Whether to save model checkpoints. save_checkpoints: bool @@ -157,8 +158,8 @@ def __post_init__(self): def rsqrt_schedule( - init_value: float, - shift: int = 0, + init_value: float, + shift: int = 0, ): """Applies a reverse square-root schedule. @@ -182,20 +183,20 @@ def schedule(count): def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" return optax.join_schedules( - [ - optax.linear_schedule( - init_value=0, - end_value=learning_rate, - transition_steps=warmup_steps, - ), - rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), - ], - boundaries=[warmup_steps], + [ + optax.linear_schedule( + init_value=0, + end_value=learning_rate, + transition_steps=warmup_steps, + ), + rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), + ], + boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( - logits, targets, weights=None, label_smoothing=0.0 + logits, targets, weights=None, label_smoothing=0.0 ): """Compute weighted cross entropy and entropy for log probs and targets. @@ -211,18 +212,18 @@ def compute_weighted_cross_entropy( """ if logits.ndim != targets.ndim + 1: raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) + 'Incorrect shapes. 
Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( - confidence * jnp.log(confidence) - + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( - targets, vocab_size, on_value=confidence, off_value=low_confidence + targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1) @@ -249,8 +250,8 @@ def compute_weighted_accuracy(logits, targets, weights=None): """ if logits.ndim != targets.ndim + 1: raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) + 'Incorrect shapes. Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) @@ -264,13 +265,13 @@ def compute_weighted_accuracy(logits, targets, weights=None): def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy( - logits, labels, weights, label_smoothing + logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { - 'loss': loss, - 'accuracy': acc, - 'denominator': weight_sum, + 'loss': loss, + 'accuracy': acc, + 'denominator': weight_sum, } return metrics @@ -280,10 +281,10 @@ def compute_metrics(logits, labels, weights, label_smoothing=0.0): def train_step( - state: TrainState, - batch, - learning_rate_fn, - label_smoothing=0.0, + state: utils.TrainState, + batch, + learning_rate_fn, + label_smoothing=0.0, ): """Perform a single training step.""" # X_position and X_segmentation 
are needed only when using "packed examples" @@ -293,16 +294,20 @@ def train_step( # like a normal, unpacked sequence example. train_keys = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets'] (inputs, inputs_positions, inputs_segmentation, targets) = ( - batch.get(k, None) for k in train_keys + batch.get(k, None) for k in train_keys ) # TODO: this should be defined globally pad_id = 0 weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32) input_mask = inputs > pad_id - attention_mask = transformer_lib.make_causal_attn_mask(input_mask) # (B, L, L) + attention_mask = transformer_lib.make_causal_attn_mask( + input_mask + ) # (B, L, L) # inputs_segmentation: (B, L) - mask = inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] # (B, L, L) + mask = ( + inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] + ) # (B, L, L) attention_mask = jnp.logical_and(mask, attention_mask) def loss_fn(params): @@ -310,14 +315,14 @@ def loss_fn(params): module = nnx.merge(state.graphdef, params) logits, _ = module( - inputs, - positions=inputs_positions, - attention_mask=attention_mask, - cache=None, + inputs, + positions=inputs_positions, + attention_mask=attention_mask, + cache=None, ) loss, weight_sum = compute_weighted_cross_entropy( - logits, targets, weights, label_smoothing + logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits @@ -334,10 +339,10 @@ def loss_fn(params): def eval_step( - params: nnx.State, - batch, - graphdef: nnx.GraphDef[transformer_lib.Transformer], - label_smoothing=0.0, + params: nnx.State, + batch, + graphdef: nnx.GraphDef[transformer_lib.Transformer], + label_smoothing=0.0, ): """Calculate evaluation metrics on a batch.""" inputs, targets = batch['inputs'], batch['targets'] @@ -351,21 +356,21 @@ def eval_step( module = nnx.merge(graphdef, params) logits, _ = module( - inputs, - positions=inputs_positions, - attention_mask=attention_mask, - cache=None, + inputs, + 
positions=inputs_positions, + attention_mask=attention_mask, + cache=None, ) return compute_metrics(logits, targets, weights, label_smoothing) def evaluate( - *, - jit_eval_step, - state: TrainState, - eval_ds: tf.data.Dataset, - num_eval_steps: int, + *, + jit_eval_step, + state: utils.TrainState, + eval_ds: tf.data.Dataset, + num_eval_steps: int, ): """Evaluate the target an return a dictionary with the metrics.""" logging.info('Gathering evaluation metrics.') @@ -379,8 +384,8 @@ def evaluate( eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree.map( - lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums, + lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop + eval_metrics_sums, ) return eval_summary @@ -406,7 +411,7 @@ def train_and_evaluate(config: TrainConfig, workdir: str): # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, encoder = input_pipeline.get_datasets( - n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path + n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path ) train_iter = iter(train_ds) @@ -417,48 +422,48 @@ def train_and_evaluate(config: TrainConfig, workdir: str): # --------------------------------------------------------------------------- if config.transformer_name is not None: model_config = transformer_lib.TransformerConfig.from_version_name( - config.transformer_name, - num_embed=vocab_size, - dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, - axis_rules=config.axis_rules, + config.transformer_name, + num_embed=vocab_size, + dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, + axis_rules=config.axis_rules, ) else: assert config.transformer_params is not None model_config = transformer_lib.TransformerConfig.from_dict( - **config.transformer_params, - 
num_embed=vocab_size, - dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, - axis_rules=config.axis_rules, + **config.transformer_params, + num_embed=vocab_size, + dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, + axis_rules=config.axis_rules, ) # Mesh definition devices_array = utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) - rng, inference_rng = random.split(rng) + _, inference_rng = random.split(rng) def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key)) learning_rate_fn = create_learning_rate_schedule( - learning_rate=config.learning_rate, warmup_steps=config.warmup_steps + learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( - learning_rate_fn, - b1=0.9, - b2=0.98, - eps=1e-9, - weight_decay=config.weight_decay, + learning_rate_fn, + b1=0.9, + b2=0.98, + eps=1e-9, + weight_decay=config.weight_decay, ) state, state_sharding = utils.setup_initial_state( - constructor, optimizer, model_config, init_rng, mesh + constructor, optimizer, model_config, init_rng, mesh ) - data_sharding = NamedSharding(mesh, P(config.data_sharding)) + data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding)) if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. @@ -467,38 +472,38 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): start_step = int(state.step) writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0 + workdir, just_logging=jax.process_index() > 0 ) if start_step == 0: writer.write_hparams(dataclasses.asdict(config)) # compile multidevice versions of train/eval/predict step fn. 
jit_train_step = jax.jit( - train_step, - in_shardings=( - state_sharding, - data_sharding, - ), # type: ignore - out_shardings=(state_sharding, None), # type: ignore - static_argnames=("learning_rate_fn", "label_smoothing"), - donate_argnums=0, + train_step, + in_shardings=( + state_sharding, + data_sharding, + ), # type: ignore + out_shardings=(state_sharding, None), # type: ignore + static_argnames=('learning_rate_fn', 'label_smoothing'), + donate_argnums=0, ) jit_eval_step = jax.jit( - eval_step, - in_shardings=( - state_sharding.params, - data_sharding, - ), # type: ignore - out_shardings=None, # type: ignore - static_argnames=("graphdef", "label_smoothing"), + eval_step, + in_shardings=( + state_sharding.params, + data_sharding, + ), # type: ignore + out_shardings=None, # type: ignore + static_argnames=('graphdef', 'label_smoothing'), ) vocab = tokenizer.load_sentencepiece_processor(vocab_path) - sampler = sampler_lib.Sampler( - transformer=nnx.merge(state.graphdef, state.params), - vocab=vocab, - cache_size=1024, + sampler = sampler_lib.Sampler( + transformer=nnx.merge(state.graphdef, state.params), + vocab=vocab, + cache_size=1024, ) # Main Train Loop @@ -509,12 +514,12 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( - num_train_steps=config.num_train_steps, writer=writer + num_train_steps=config.num_train_steps, writer=writer ) if jax.process_index() == 0: hooks += [ - report_progress, - periodic_actions.Profile(logdir=workdir, num_profile_steps=5), + report_progress, + periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): @@ -525,12 +530,12 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): with jax.profiler.StepTraceAnnotation('train', step_num=step): with report_progress.timed('data'): batch = next(train_iter) - batch = 
jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch) + batch = jax.tree.map( + lambda x: jnp.asarray(x, device=data_sharding), batch + ) with report_progress.timed('train_step'): - state, metrics = jit_train_step( - state, batch, learning_rate_fn, 0.0 - ) + state, metrics = jit_train_step(state, batch, learning_rate_fn, 0.0) train_metrics.append(metrics) # Quick indication that training is happening. @@ -541,14 +546,17 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): # Write batch loss and lr every step to TB # without overwhelming the stdout: if jax.process_index() == 0: - tb_writer = writer._writers[-1] + tb_writer = writer._writers[-1] # pylint: disable=protected-access lr = train_metrics[-1]['learning_rate'] train_batch_loss = train_metrics[-1]['loss'] denominator = train_metrics[-1]['denominator'] - tb_writer.write_scalars(step, { - "train_learning_rate": lr, - "train_loss": train_batch_loss / denominator, - }) + tb_writer.write_scalars( + step, + { + 'train_learning_rate': lr, + 'train_loss': train_batch_loss / denominator, + }, + ) # Periodic metric handling. 
if (step > 0 and step % config.eval_every_steps == 0) or is_last_step: @@ -569,33 +577,33 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): # update sampler's transformer state: sampler.transformer_state = state.params exemplars = sampler( - config.prompts, - total_generation_steps=config.num_predict_steps, - temperature=config.sampling_temperature, - top_p=config.sampling_top_p, - seed=inference_rng, - echo=True, + config.prompts, + total_generation_steps=config.num_predict_steps, + temperature=config.sampling_temperature, + top_p=config.sampling_top_p, + seed=inference_rng, + echo=True, ) - writer.write_texts(step, {'samples': exemplars.text}) + writer.write_texts(step, {'samples': exemplars.text[0]}) with report_progress.timed('eval'): eval_results = evaluate( - jit_eval_step=jit_eval_step, - state=state, - eval_ds=eval_ds, - num_eval_steps=config.num_eval_steps, + jit_eval_step=jit_eval_step, + state=state, + eval_ds=eval_ds, + num_eval_steps=config.num_eval_steps, ) # (clipped) perplexity after averaging log-perplexity eval_results['perplexity'] = jnp.clip( - jnp.exp(eval_results['loss']), max=1.0e4 + jnp.exp(eval_results['loss']), max=1.0e4 ) writer.write_scalars( - step, {'eval_' + k: v for k, v in eval_results.items()} + step, {'eval_' + k: v for k, v in eval_results.items()} ) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( - step % config.checkpoint_every_steps == 0 or is_last_step + step % config.checkpoint_every_steps == 0 or is_last_step ) if config.save_checkpoints and save_checkpoint: logging.info('Saving checkpoint step %d.', step) diff --git a/examples/gemma/utils.py b/examples/gemma/utils.py index 18f6909cc..5162c0618 100644 --- a/examples/gemma/utils.py +++ b/examples/gemma/utils.py @@ -12,39 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). +# Copied over from MaxText +# (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). +"""Provides utilities for training the Flax gemma example.""" -import logging -from typing import Any, TYPE_CHECKING from collections.abc import Callable +import logging +from typing import Any +from flax import nnx +import transformer +from flax.training import train_state import jax +from jax.experimental import mesh_utils import jax.numpy as jnp import numpy as np -from jax.experimental import mesh_utils -from transformer import TransformerConfig, Transformer -from flax import nnx -from flax.training import train_state - -if TYPE_CHECKING: - from train import TrainConfig - Dtype = Any Shape = tuple[int, ...] class TrainState(train_state.TrainState): - graphdef: nnx.GraphDef[Transformer] + graphdef: nnx.GraphDef[transformer.Transformer] # Mesh utils. # ----------------------------------------------------------------------------- -def create_device_mesh(config: "TrainConfig"): - """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas.""" +def create_device_mesh(config: Any): + """Creates a device mesh with each slice in its own data parallel group. + + If there is only one slice, uses two replicas. + + Args: + config: The training configuration. + Returns: + The device mesh. 
def fill_unspecified_mesh_axes(
    parallelism_vals, target_product, parallelism_type
):
  """Evaluates unspecified DCN/ICI parallelism values.

  At most one entry of `parallelism_vals` may be -1, meaning "unspecified".
  That entry is replaced by the value that makes the product of all entries
  equal to `target_product`. The final product is validated whether or not
  an entry was unspecified.

  Args:
    parallelism_vals: List of per-axis parallelism degrees. Modified in place.
    target_product: The value the product of all axis degrees must equal.
    parallelism_type: Label used in error messages. 'DCN' selects the
      'slices' wording; any other value selects 'devices per slice'.

  Returns:
    The same list object that was passed in, with any -1 entry filled in.

  Raises:
    AssertionError: If more than one entry is -1, if the unspecified entry
      cannot be resolved to an integer >= 1, or if the product of the
      entries does not equal `target_product`.
  """
  if -1 in parallelism_vals:
    assert parallelism_vals.count(-1) == 1, (
        f'Found unspecified values (-1) for more than one {parallelism_type} '
        ' parallelism axis. At most one axis can be unspecified.'
    )

    # The product contains exactly one factor of -1, so negating the quotient
    # yields the size the unspecified axis must have.
    determined_val = target_product / np.prod(parallelism_vals) * -1

    # `is_integer` has to be *called*: the bare bound method is always truthy,
    # which let fractional sizes slip through and be truncated by `int()`.
    assert determined_val >= 1 and float(determined_val).is_integer(), (
        'Unspecified value unable to be determined with the given '
        f' {parallelism_type} parallelism values'
    )

    parallelism_vals[parallelism_vals.index(-1)] = int(determined_val)

  target_type = 'slices' if parallelism_type == 'DCN' else 'devices per slice'

  assert np.prod(parallelism_vals) == target_product, (
      f'Number of {target_type} {target_product} does not match the product'
      f' of the {parallelism_type} parallelism {np.prod(parallelism_vals)}'
  )

  return parallelism_vals
Args: constructor: the model constructor @@ -155,10 +162,10 @@ def sharded_init(): model = constructor(config, rng) graphdef, params = nnx.split(model, nnx.Param) state = TrainState.create( - apply_fn=graphdef.apply, - params=params, - tx=tx, - graphdef=graphdef, + apply_fn=graphdef.apply, + params=params, + tx=tx, + graphdef=graphdef, ) state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state)