diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt index 41971d290..1d711ba84 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt @@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt index a27337377..f9ede244a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt @@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt index 5cc38c163..b1c5963dc 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host.txt @@ -231,7 +231,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt new file mode 100644 index 000000000..f9cd08548 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt @@ -0,0 +1,284 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' 
+checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' 
+evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 
'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 
+model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt @@ -0,0 +1,9 @@ 
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt new file mode 100644 index 000000000..f2a9b902d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt @@ -0,0 +1,284 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 
+evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 
'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 2048 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' 
+model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' 
+model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / 
fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt new file mode 100644 index 000000000..8a9fd02aa --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host.txt @@ -0,0 +1,251 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' 
+evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' 
+input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' 
+model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' 
+model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 
'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-single-host_regularizer.txt @@ -0,0 +1,10 @@ 
+====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt new file mode 100644 index 000000000..38938b479 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken.txt @@ -0,0 +1,251 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 
'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 2048 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 
'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 2048 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 32 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 16 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' 
+model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt new file mode 100644 index 000000000..5c4658cf7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt index 86c13eb79..691c7a6c6 100644 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3.txt @@ -231,7 +231,7 @@ model.decoder.transformer.num_layers: 16 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt index 5c4658cf7..723be0837 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt index 32be1295c..a95cf6a56 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host.txt @@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 
3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt index 3de7d2b95..f82ae8c5a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash.txt @@ -264,7 +264,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt index 7cc3b4afc..f07237237 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host.txt @@ -231,7 +231,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' 
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-single-host_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt new file mode 100644 index 000000000..f0d0ac350 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt @@ -0,0 +1,284 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' 
+evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 
+input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' 
+model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' 
+model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt new file mode 100644 index 000000000..496892201 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt @@ -0,0 +1,284 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 
+evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 2048 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 
+learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' 
+model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 
+model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 
'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt new file mode 100644 index 000000000..02c2501fe --- /dev/null +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host.txt @@ -0,0 +1,251 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 16 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 16 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False 
+evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 16 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 
+model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' 
+model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 
'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt new file mode 100644 index 000000000..b16c157f2 --- 
/dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-single-host_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt new file mode 100644 index 000000000..3b0fa4cae --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken.txt @@ -0,0 +1,251 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 
'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' 
+evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 2048 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0003 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 8 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 3072 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' 
+model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim: 8192 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' 
+model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 24 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 28 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt new file mode 100644 index 000000000..b16c157f2 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(3072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 3072), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt index 612565b6f..7363cf779 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3.txt @@ -231,7 +231,7 @@ model.decoder.transformer.num_layers: 28 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt index b16c157f2..a42dd020d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 3072], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 40, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(3072, 24, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, 
batch_axis=()) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt index e3f269bfa..4e364f076 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt @@ -341,7 +341,7 @@ model.decoder.transformer.num_layers: 80 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt index f0e1c9fec..8730d5928 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt new file mode 100644 index 000000000..364237fdf --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt @@ -0,0 +1,361 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 
'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' 
+evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 2048 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 
+mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' 
+mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 +mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' 
+model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 
+model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_init.txt new file mode 100644 index 000000000..f0e1c9fec --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), 
axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt new file mode 100644 index 000000000..ac897293f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken.txt @@ -0,0 +1,328 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 50000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 983040 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 5000 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 983040 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 5000 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 2048 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 8192 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 
'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 983040 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 5000 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 2048 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 8192 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 2048 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 8192 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' 
+learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.00015 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 983040 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.1 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 983040 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_rules[0][0]: 'tpu-v5litepod-256-4' +mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[1][0]: 'tpu-v6e-256-(4|8)' +mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None 
+mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[2][0]: 'tpu-v6e-256' +mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[2][1].config_modifiers[2].grad_acc_steps: 4 +mesh_rules[2][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' +mesh_rules[2][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[3][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' +mesh_rules[3][1][0]: 1 +mesh_rules[3][1][1]: -1 +mesh_rules[3][1][2]: 1 +mesh_rules[3][1][3]: 128 +mesh_rules[3][1][4]: 1 +mesh_rules[3][1][5]: 1 +mesh_rules[4][0]: 'neuron-(trn2|trn2n).48xlarge-64' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 4 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: '.*[kqv]_proj|.*linear1_[01]' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: None 
+mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: None +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_shape[0]: 1 +mesh_shape[1]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8192 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' +model.decoder.lm_head.param_partition_spec[0]: None +model.decoder.lm_head.param_partition_spec[1]: 'model' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 
'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' 
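The feed-forward hidden size in these configs is derived rather than stated: `scale_fn` with scale 3.5 and round_up_to_multiples_of 256 is applied to the model dim. A worked check for the 8192-dim model (a sketch of the configured rule, not the axlearn `scale_fn` itself) reproduces the (8192, 28672) linear1 shapes that appear in the *_init.txt files further down.

```python
import math

def ffn_hidden_dim(model_dim: int, scale: float = 3.5, multiple: int = 256) -> int:
    """Round model_dim * scale up to the nearest multiple, per the config above."""
    return math.ceil(model_dim * scale / multiple) * multiple

print(ffn_hidden_dim(8192))  # 28672: 8192 * 3.5 is already a multiple of 256
```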
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 64 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 80 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 128256 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 
'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt new file mode 100644 index 000000000..f0e1c9fec --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_init.txt @@ -0,0 +1,10 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt index b7457c951..b82f94234 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt @@ -308,7 +308,7 @@ model.decoder.transformer.num_layers: 80 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 +model.decoder.vocab_size: 131072 model.dtype: 'jax.numpy.float32' model.klass: 'axlearn.common.causal_lm.Model' model.param_init.init_by_param_name['.*weight$'].distribution: 
'normal' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt index f0e1c9fec..8730d5928 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3_init.txt @@ -1,4 +1,4 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) @@ -7,4 +7,4 @@ decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file +decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt index 32982ee4c..a9994df1a 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 
'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 
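The mesh shapes in these rules are 6-tuples over the axes ('pipeline', 'data', 'expert', 'fsdp', 'seq', 'model'), with -1 marking the axis that absorbs whatever device count remains. The helper below is an illustrative sketch of that resolution step, not the `MeshShapeModifier`; the 1024-device example assumes 4 slices of 256 chips, which is only one reading of the 'tpu-v6e-256-(4|8)' rule name.

```python
import numpy as np

AXIS_NAMES = ("pipeline", "data", "expert", "fsdp", "seq", "model")

def resolve_mesh_shape(shape, num_devices):
    """Fill in a single -1 entry so the product of the shape equals num_devices."""
    shape = list(shape)
    fixed = int(np.prod([s for s in shape if s != -1]))
    if -1 in shape:
        assert num_devices % fixed == 0, "fixed axes must divide the device count"
        shape[shape.index(-1)] = num_devices // fixed
    return tuple(shape)

# e.g. [1, -1, 1, 256, 1, 1] on 1024 devices: the 'data' axis picks up the factor of 4.
print(resolve_mesh_shape([1, -1, 1, 256, 1, 1], num_devices=1024))  # (1, 4, 1, 256, 1, 1)

# The resolved tuple would then back a device mesh, e.g.:
#   mesh = jax.sharding.Mesh(np.asarray(jax.devices()).reshape(shape), AXIS_NAMES)
```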
+mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 @@ -188,6 +204,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' @@ -251,14 +268,11 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' model.decoder.transformer.layer.feed_forward.structure: 'prenorm' model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt index 74483207b..0e4dd19af 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 
+evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 
+mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 @@ -188,6 +204,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' @@ -251,14 +268,11 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' model.decoder.transformer.layer.feed_forward.structure: 'prenorm' model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None 
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt index 99be5fdb6..a07147b7e 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt @@ -7,14 +7,14 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -41,7 +41,7 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 
'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 +mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 @@ -188,6 +204,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' @@ -251,12 +268,11 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' model.decoder.transformer.layer.feed_forward.structure: 'prenorm' model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 
'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt index 4d32be7c7..375b91fd8 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt @@ -7,18 +7,18 @@ checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 +checkpointer.save_policy.max_step: 983040 checkpointer.save_policy.min_step: 1 checkpointer.save_policy.n: 5000 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' checkpointer.storage.timeout_secs: 3600 evalers['train'].eval_dtype: 'jax.numpy.bfloat16' evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 +evalers['train'].eval_policy.max_step: 983040 evalers['train'].eval_policy.min_step: 1 evalers['train'].eval_policy.n: 5000 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 +evalers['train'].input.batcher.global_batch_size: 2048 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['train'].input.batcher.prefetch_buffer_size: -1 evalers['train'].input.is_training: False @@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri evalers['train'].summary_writer.write_every_n_steps: 1 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 +evalers['validation'].eval_policy.max_step: 983040 evalers['validation'].eval_policy.min_step: 1 evalers['validation'].eval_policy.n: 5000 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 +evalers['validation'].input.batcher.global_batch_size: 2048 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' evalers['validation'].input.batcher.prefetch_buffer_size: -1 evalers['validation'].input.is_training: False @@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward' evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' evalers['validation'].summary_writer.write_every_n_steps: 1 input.batcher.fn: 
'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 +input.batcher.global_batch_size: 2048 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' input.batcher.prefetch_buffer_size: -1 input.is_training: True @@ -102,12 +102,12 @@ learner.optimizer.args[1].learning_rate: 0.0003 learner.optimizer.args[1].update_schedule.alpha: 0.1 learner.optimizer.args[1].update_schedule.begin_value: 0.0 learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 +learner.optimizer.args[1].update_schedule.max_step: 983040 learner.optimizer.args[1].update_schedule.peak_lr: 1.0 learner.optimizer.args[1].update_schedule.warmup_steps: 2000 learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 +max_step: 983040 mesh_axis_names[0]: 'pipeline' mesh_axis_names[1]: 'data' mesh_axis_names[2]: 'expert' @@ -164,20 +164,36 @@ mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modif mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[4][0]: 'tpu-v6e-256-(2|4|8)' +mesh_rules[4][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' +mesh_rules[4][1].config_modifiers[0].mesh_shape[0]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[1]: -1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[2]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[3]: 256 +mesh_rules[4][1].config_modifiers[0].mesh_shape[4]: 1 +mesh_rules[4][1].config_modifiers[0].mesh_shape[5]: 1 +mesh_rules[4][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.save_and_offload_only_these_names_regex' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_offloaded: '.*([qkvo]_proj|context)' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].names_which_can_be_saved: None +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' +mesh_rules[4][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' +mesh_rules[4][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' +mesh_rules[5][0]: 'tpu-v5p-.*' mesh_rules[5][1][0]: 1 mesh_rules[5][1][1]: -1 mesh_rules[5][1][2]: 1 mesh_rules[5][1][3]: 8 mesh_rules[5][1][4]: 1 mesh_rules[5][1][5]: 1 +mesh_rules[6][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' +mesh_rules[6][1][0]: 1 +mesh_rules[6][1][1]: -1 +mesh_rules[6][1][2]: 1 +mesh_rules[6][1][3]: 8 +mesh_rules[6][1][4]: 1 
+mesh_rules[6][1][5]: 1 mesh_shape[0]: 1 mesh_shape[1]: -1 mesh_shape[2]: 1 @@ -188,6 +204,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' @@ -251,12 +268,11 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' model.decoder.transformer.layer.feed_forward.structure: 'prenorm' model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' model.decoder.transformer.layer.self_attention.attention.causal: True model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt similarity index 97% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt index a15dfdf0b..b48178e92 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt @@ -31,8 +31,8 @@ evalers['train'].input.source.max_sequence_length: 8192 evalers['train'].input.source.replace_newlines_with: '\n' evalers['train'].input.source.split: 'train[:8192]' evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['train'].klass: 
'axlearn.common.evaler.SpmdEvaler' evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -58,8 +58,8 @@ evalers['validation'].input.source.max_sequence_length: 8192 evalers['validation'].input.source.replace_newlines_with: '\n' evalers['validation'].input.source.split: 'validation' evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -85,8 +85,8 @@ input.source.preprocessor.max_padding_fraction: 0.5 input.source.preprocessor.shuffle_buffer_size: 8192 input.source.preprocessor.window_size: 128 input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' klass: 'axlearn.common.trainer.SpmdTrainer' learner.ema.fn: 'axlearn.common.optimizers.param_ema' learner.enable_per_variable_summaries: False @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt similarity index 97% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt index 7f520cbde..f9de9c5a6 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt @@ -31,8 +31,8 @@ evalers['train'].input.source.max_sequence_length: 8192 evalers['train'].input.source.replace_newlines_with: '\n' evalers['train'].input.source.split: 'train[:8192]' evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -58,8 +58,8 @@ 
evalers['validation'].input.source.max_sequence_length: 8192 evalers['validation'].input.source.replace_newlines_with: '\n' evalers['validation'].input.source.split: 'validation' evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -85,8 +85,8 @@ input.source.preprocessor.max_padding_fraction: 0.5 input.source.preprocessor.shuffle_buffer_size: 8192 input.source.preprocessor.window_size: 128 input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' klass: 'axlearn.common.trainer.SpmdTrainer' learner.ema.fn: 'axlearn.common.optimizers.param_ema' learner.enable_per_variable_summaries: False @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_regularizer.txt diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt similarity index 97% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt index 225299e7b..dddec1683 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host.txt @@ -31,8 +31,8 @@ evalers['train'].input.source.max_sequence_length: 8192 evalers['train'].input.source.replace_newlines_with: '\n' evalers['train'].input.source.split: 'train[:8192]' evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -58,8 +58,8 @@ evalers['validation'].input.source.max_sequence_length: 8192 evalers['validation'].input.source.replace_newlines_with: '\n' evalers['validation'].input.source.split: 'validation' evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -85,8 +85,8 @@ input.source.preprocessor.max_padding_fraction: 0.5 input.source.preprocessor.shuffle_buffer_size: 8192 input.source.preprocessor.window_size: 128 input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' klass: 'axlearn.common.trainer.SpmdTrainer' learner.ema.fn: 'axlearn.common.optimizers.param_ema' learner.enable_per_variable_summaries: False @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-single-host_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt similarity index 97% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt index 6339517df..366507be9 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken.txt @@ -31,8 +31,8 @@ evalers['train'].input.source.max_sequence_length: 8192 evalers['train'].input.source.replace_newlines_with: '\n' evalers['train'].input.source.split: 'train[:8192]' evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' 
-evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -58,8 +58,8 @@ evalers['validation'].input.source.max_sequence_length: 8192 evalers['validation'].input.source.replace_newlines_with: '\n' evalers['validation'].input.source.split: 'validation' evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' @@ -85,8 +85,8 @@ input.source.preprocessor.max_padding_fraction: 0.5 input.source.preprocessor.shuffle_buffer_size: 8192 input.source.preprocessor.window_size: 128 input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' klass: 'axlearn.common.trainer.SpmdTrainer' learner.ema.fn: 'axlearn.common.optimizers.param_ema' learner.enable_per_variable_summaries: False @@ -201,7 +201,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1. 
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' model.decoder.emb.token_emb.param_partition_spec[0]: None model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 +model.decoder.eos_token_id: 128001 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' model.decoder.lm_head.param_partition_spec[0]: None @@ -215,7 +215,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 model.decoder.output_norm.forward_dtype: None model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 +model.decoder.pad_token_id: 128004 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..65733fb7f --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken_regularizer.txt @@ -0,0 +1,11 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/lm_head/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken.txt new file mode 100644 index 000000000..5054cc1e7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken.txt @@ -0,0 +1,249 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 3000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 500 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 
'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 1500 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 64 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 1500 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 64 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 
'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 64 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.3 +learner.optimizer.args[1].update_schedule: 1 +learner.optimizer.args[1].weight_decay: 0.01 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 5 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 
'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 
'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' 
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 4 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' +vlog: 1 \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt new file mode 100644 index 000000000..61615aa53 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) 
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash.txt new file mode 100644 index 000000000..1cd4387ca --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash.txt @@ -0,0 +1,286 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 3000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 500 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 1500 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 64 +evalers['train'].input.source.replace_newlines_with: '\n' 
+evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 1500 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 64 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 64 +input.source.preprocessor.fn: 
'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 +learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0006 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.01 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' 
+model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' 
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None 
+model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None +model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 
+model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 4 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt new file mode 100644 index 000000000..61615aa53 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt new file mode 100644 index 000000000..32cac14f7 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken.txt @@ -0,0 +1,253 @@ +batch_axis_names[0]: 'data' +batch_axis_names[1]: 'expert' +batch_axis_names[2]: 'fsdp' +batch_axis_names[3]: 'seq' +checkpointer.gc_loop_interval_seconds: 60 +checkpointer.keep_every_n_steps: 3000 +checkpointer.keep_last_n: 3 +checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' +checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' +checkpointer.save_policy.max_step: 3000 +checkpointer.save_policy.min_step: 1 +checkpointer.save_policy.n: 500 +checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' +checkpointer.storage.timeout_secs: 3600 +evalers['train'].eval_dtype: 'jax.numpy.bfloat16' +evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['train'].eval_policy.max_step: 3000 +evalers['train'].eval_policy.min_step: 1 +evalers['train'].eval_policy.n: 1500 +evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['train'].input.batcher.global_batch_size: 32 +evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['train'].input.batcher.prefetch_buffer_size: -1 +evalers['train'].input.is_training: False +evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['train'].input.source.is_training: False +evalers['train'].input.source.max_sequence_length: 64 +evalers['train'].input.source.replace_newlines_with: '\n' +evalers['train'].input.source.split: 'train[:8192]' +evalers['train'].input.source.train_shuffle_buffer_size: 16384 +evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['train'].metric_calculator.model_method: 'forward' +evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' 
+evalers['train'].summary_writer.write_every_n_steps: 1 +evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' +evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' +evalers['validation'].eval_policy.max_step: 3000 +evalers['validation'].eval_policy.min_step: 1 +evalers['validation'].eval_policy.n: 1500 +evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' +evalers['validation'].input.batcher.global_batch_size: 32 +evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +evalers['validation'].input.batcher.prefetch_buffer_size: -1 +evalers['validation'].input.is_training: False +evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' +evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' +evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' +evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' +evalers['validation'].input.source.is_training: False +evalers['validation'].input.source.max_sequence_length: 64 +evalers['validation'].input.source.replace_newlines_with: '\n' +evalers['validation'].input.source.split: 'validation' +evalers['validation'].input.source.train_shuffle_buffer_size: 16384 +evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' +evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' +evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' +evalers['validation'].metric_calculator.model_method: 'forward' +evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +evalers['validation'].summary_writer.write_every_n_steps: 1 +input.batcher.fn: 'axlearn.common.input_tf_data.batch' +input.batcher.global_batch_size: 32 +input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' +input.batcher.prefetch_buffer_size: -1 +input.is_training: True +input.klass: 'axlearn.common.input_tf_data.Input' +input.processor.fn: 'axlearn.common.input_tf_data.identity' +input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' +input.source.data_mixture_components[0]['weight']: 1.0 +input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 +input.source.data_mixture_components[0]['split']: 'train' +input.source.data_mixture_components[0]['info']: '' +input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' +input.source.max_sequence_length: 64 +input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' +input.source.preprocessor.max_padding_fraction: 0.5 +input.source.preprocessor.shuffle_buffer_size: 8192 +input.source.preprocessor.window_size: 128 +input.source.replace_newlines_with: '' +input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json' +input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary' +klass: 'axlearn.common.trainer.SpmdTrainer' +learner.ema.fn: 'axlearn.common.optimizers.param_ema' +learner.enable_per_variable_summaries: False +learner.klass: 'axlearn.common.learner.Learner' +learner.optimizer.args[0].eps: 1e-08 +learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' +learner.optimizer.args[0].max_norm: 1 
+learner.optimizer.args[1].b1: 0.9 +learner.optimizer.args[1].b2: 0.95 +learner.optimizer.args[1].eps: 1e-08 +learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' +learner.optimizer.args[1].learning_rate: 0.0006 +learner.optimizer.args[1].update_schedule.alpha: 0.1 +learner.optimizer.args[1].update_schedule.begin_value: 0.0 +learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' +learner.optimizer.args[1].update_schedule.max_step: 3000 +learner.optimizer.args[1].update_schedule.peak_lr: 1.0 +learner.optimizer.args[1].update_schedule.warmup_steps: 2000 +learner.optimizer.args[1].weight_decay: 0.01 +learner.optimizer.fn: 'axlearn.common.optimizers.chain' +max_step: 3000 +mesh_axis_names[0]: 'pipeline' +mesh_axis_names[1]: 'data' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'seq' +mesh_axis_names[5]: 'model' +mesh_shape[0]: 1 +mesh_shape[1]: -1 +mesh_shape[2]: 1 +mesh_shape[3]: 1 +mesh_shape[4]: 1 +mesh_shape[5]: 1 +model.batch_axis_names[0]: 'data' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' +model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' +model.decoder.dim: 8 +model.decoder.dropout_rate: 0.0 +model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' +model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.decoder.emb.token_emb.param_partition_spec[0]: None +model.decoder.emb.token_emb.param_partition_spec[1]: 'model' +model.decoder.eos_token_id: 128001 +model.decoder.klass: 'axlearn.common.decoder.Decoder' +model.decoder.logits_partition_spec[0][0]: 'data' +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' +model.decoder.logits_partition_spec[2]: 'model' +model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.output_norm.eps: 1e-05 +model.decoder.output_norm.forward_dtype: None +model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.pad_token_id: 128004 +model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' +model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' +model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' +model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' +model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 16 +model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665 +model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' +model.decoder.transformer.layer.feed_forward.linear1.bias: False +model.decoder.transformer.layer.feed_forward.linear1.klass: 
'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.bias: False +model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' +model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 +model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None +model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 +model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.feed_forward.structure: 'prenorm' +model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' +model.decoder.transformer.layer.remat_spec['prevent_cse']: False +model.decoder.transformer.layer.remat_spec['policy'].fn: 'axlearn.common.attention._save_and_offload_only_these_names_regex' +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded: None +model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved: '.*([qkvo]_proj|context)' +model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host' +model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device' +model.decoder.transformer.layer.self_attention.attention.causal: True +model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' 
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 2 +model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' +model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 +model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False +model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' +model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' +model.decoder.transformer.layer.self_attention.attention.num_heads: 4 +model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False +model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None +model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' +model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' +model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' +model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 +model.decoder.transformer.layer.self_attention.norm.forward_dtype: None +model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' +model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' +model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' +model.decoder.transformer.layer.self_attention.structure: 'prenorm' +model.decoder.transformer.num_layers: 4 +model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' +model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' +model.decoder.transformer.repeat.klass: 
'axlearn.common.attention._TransformerRepeat' +model.decoder.vocab_size: 32 +model.dtype: 'jax.numpy.float32' +model.klass: 'axlearn.common.causal_lm.Model' +model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' +model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' +model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' +model.param_init.init_by_param_name['.*weight$'].scale: 1.0 +model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names[0]: 'seq' +model.z_loss_scale: 0.0 +name: 'gpt_trainer' +prune_empty_state_updates: True +save_input_iterator: False +start_trace_process_indices[0]: 0 +summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' +summary_writer.max_queue: 1000 +summary_writer.write_every_n_steps: 100 +train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt new file mode 100644 index 000000000..61615aa53 --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_init.txt @@ -0,0 +1,9 @@ +decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) +decoder/output_norm/scale: constant(1.0) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt new file mode 100644 index 000000000..03fb7437d --- /dev/null +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken_regularizer.txt @@ -0,0 +1,10 @@ +====================weight_decay_scale root.optimizer==================== +decoder/emb/token_emb/weight: 1 +decoder/output_norm/scale: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 +decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 +decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 +decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 +decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json deleted file mode 100644 index 9d03fc1e0..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-70B.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 8192, - "initializer_range": 0.02, - "intermediate_size": 28672, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 64, - "num_hidden_layers": 80, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.43.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json deleted file mode 100644 index cccf055d6..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.1-8B.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 14336, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.43.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json deleted file mode 100644 index 83b8b2aeb..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-1B.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "head_dim": 64, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 8192, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 16, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 32.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", 
- "transformers_version": "4.45.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json deleted file mode 100644 index 47d4a5aa6..000000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.param_converter_test/Llama-3.2-3B.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "head_dim": 128, - "hidden_act": "silu", - "hidden_size": 3072, - "initializer_range": 0.02, - "intermediate_size": 8192, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 24, - "num_hidden_layers": 28, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": { - "factor": 32.0, - "high_freq_factor": 4.0, - "low_freq_factor": 1.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3" - }, - "rope_theta": 500000.0, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "transformers_version": "4.45.0.dev0", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py index 8c70422e3..f48af19f9 100644 --- a/axlearn/experiments/text/gpt/c4_trainer.py +++ b/axlearn/experiments/text/gpt/c4_trainer.py @@ -41,22 +41,29 @@ """ -from axlearn.common.config import InstantiableConfig, config_for_function +from axlearn.common.config import InstantiableConfig, config_for_class, config_for_function from axlearn.common.input_lm import lm_text_preprocessor from axlearn.common.utils import get_data_dir from axlearn.experiments.text.common import DataMixtureComponent, vocab from axlearn.experiments.text.gpt import fuji, gspmd from axlearn.experiments.text.gpt.common import mixture_train_input_source, tfds_input +from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary from axlearn.experiments.trainer_config_utils import TrainerConfigFn -# Sentencepiece vocabs generated from c4/en:3.0.1. -# See bpe_{32k,128k}.json for the sentencepiece settings. -_SENTENCEPIECE_MODEL_NAME = { - 32 * 1024: "bpe_32k_c4.model", - # TikToken is not yet supported, so we are using sentencepiece for now. - # Our new grain-based inputs can support TikToken in the future. - 128256: "bpe_128k_c4.model", -} + +def _vocab_cfg(vocab_size: int): + if vocab_size == 32 * 1024: + # Sentencepiece vocabs generated from c4/en:3.0.1. + # See bpe_{32k,128k}.json for the sentencepiece settings. + return config_for_function(vocab).set(sentencepiece_model_name="bpe_32k_c4.model") + if vocab_size == 128 * 1024: + return config_for_function(vocab).set(sentencepiece_model_name="bpe_128k_c4.model") + if vocab_size == 128256: + # TikToken. 
+ return config_for_class(FujiV3Vocabulary).set(filename="Llama-3-tokenizer.json") + raise ValueError(f"Tokenizer with vocab size {vocab_size} does not exist.") + + _train_data_mixture_components = [ DataMixtureComponent( name="c4/en:3.0.1", @@ -75,9 +82,7 @@ def _eval_input_sources( dataset_name="c4/en:3.0.1", split=split, is_training=False, - vocab_cfg=config_for_function(vocab).set( - sentencepiece_model_name=_SENTENCEPIECE_MODEL_NAME[vocab_size] - ), + vocab_cfg=_vocab_cfg(vocab_size), max_sequence_length=max_sequence_length, ) for name, split in (("train", "train[:8192]"), ("validation", "validation")) @@ -87,9 +92,7 @@ def _eval_input_sources( def _train_input_source(*, vocab_size: int, max_sequence_length: int) -> InstantiableConfig: source_cfg = config_for_function(mixture_train_input_source).set( data_mixture_components=_train_data_mixture_components, - vocab_cfg=config_for_function(vocab).set( - sentencepiece_model_name=_SENTENCEPIECE_MODEL_NAME[vocab_size] - ), + vocab_cfg=_vocab_cfg(vocab_size), max_sequence_length=max_sequence_length, preprocessor=config_for_function(lm_text_preprocessor).set(max_padding_fraction=0.5), ) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index c01789a39..3d433094c 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -223,6 +223,8 @@ def model_config( ffn_structure: str = "prenorm", atten_structure: str = "prenorm", atten_logit_cap: Optional[float] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, ) -> causal_lm.Model.Config: """Returns an LM model config based on the given hyperparams. @@ -253,6 +255,8 @@ def model_config( atten_logit_cap: Cap the absolute values of logits by tanh. Enabled by setting a positive value. remat_offload_dst: Destination of remat checkptoing offloading. + pad_token_id: Int ID of the inputs to be masked for self-attention. + eos_token_id: Int ID of the end of sequence token id. Returns: A causal LM config. @@ -283,6 +287,10 @@ def model_config( lm_head=lm_head_cfg, dropout_rate=dropout_rate, ) + if pad_token_id is not None: + decoder_cfg.set(pad_token_id=pad_token_id) + if eos_token_id is not None: + decoder_cfg.set(eos_token_id=eos_token_id) # Model. model_param_init = DefaultInitializer.default_config().set( init_by_param_name={ diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 259e6a960..bbd769dad 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -65,13 +65,15 @@ class Version(enum.Enum): V1 = 1 V2 = 2 V3 = 3 + V3_TIKTOKEN = "3-tiktoken" # Mapping from Fuji versions to vocab sizes. VOCAB_SIZE = { Version.V1: 32 * 1024, Version.V2: 32 * 1024, - Version.V3: 128256, + Version.V3: 128 * 1024, + Version.V3_TIKTOKEN: 128256, } @@ -80,6 +82,7 @@ class Version(enum.Enum): Version.V1: 2048, Version.V2: 4096, Version.V3: 8192, + Version.V3_TIKTOKEN: 8192, } @@ -87,6 +90,7 @@ class Version(enum.Enum): Version.V1: 1e4, Version.V2: 1e4, Version.V3: 5e5, + Version.V3_TIKTOKEN: 5e5, } # Mapping from Fuji versions to total number of tokens used in training. 
@@ -102,6 +106,13 @@ class Version(enum.Enum): "70B": 2 * (1024**4), # 2T tokens }, Version.V3: { + "test": 15 * (1024**4), # 15T tokens + "1B": 15 * (1024**4), # 15T tokens + "3B": 15 * (1024**4), # 15T tokens + "7B": 15 * (1024**4), # 15T tokens + "70B": 15 * (1024**4), # 15T tokens + }, + Version.V3_TIKTOKEN: { "test": 15 * (1024**4), # 15T tokens "1B": 15 * (1024**4), # 15T tokens "3B": 15 * (1024**4), # 15T tokens @@ -116,6 +127,7 @@ class Version(enum.Enum): Version.V1: 4 * (1024**2), Version.V2: 4 * (1024**2), Version.V3: 16 * (1024**2), + Version.V3_TIKTOKEN: 16 * (1024**2), } @@ -128,15 +140,13 @@ def get_trainer_kwargs( ) -> dict[str, Any]: """Construct default trainer kwargs given a model size.""" tokens_per_batch = TOKENS_PER_BATCH[version] - if model_size not in TOTAL_TOKENS[version]: - return {} max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch max_sequence_length = MAX_SEQUENCE_LENGTH[version] train_batch_size = tokens_per_batch // max_sequence_length # Whether to use grouped query attention. num_kv_heads = None - if version == Version.V3: + if version in (Version.V3, Version.V3_TIKTOKEN): num_kv_heads = 8 rope_theta = ROPE_THETA[version] @@ -530,6 +540,9 @@ def get_trainer_kwargs( raise NotImplementedError(f"Unknown model size {model_size}.") model_kwargs = trainer_kwargs.pop("model_kwargs") model_kwargs.setdefault("vocab_size", vocab_size) + if version == Version.V3_TIKTOKEN: # tiktoken tokenizer + model_kwargs["pad_token_id"] = 128004 + model_kwargs["eos_token_id"] = 128001 trainer_kwargs["model_cfg"] = model_config(**model_kwargs) trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config( max_step=trainer_kwargs["max_step"], @@ -552,6 +565,8 @@ def model_config( ffn_dim: Optional[Union[int, config.FunctionConfigBase]] = None, flash_attention: bool = False, stack_cfg: Optional[BaseStackedTransformerLayer.Config] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, ) -> causal_lm.Model.Config: """Returns an LM model config based on the given hyperparams. @@ -570,6 +585,8 @@ def model_config( flash_attention: Whether to enable flash attention. stack_cfg: The transformer stack config. If None, defaults to a RepeatedTransformerLayer. + pad_token_id: Int ID of the inputs to be masked for self-attention. + eos_token_id: Int ID of the end of sequence token id. Returns: A causal LM config. @@ -607,6 +624,8 @@ def model_config( lm_head_cfg=LmHead.default_config() if not shared_lm_head else None, attention_cfg=flash_attention_config() if flash_attention else atten_cfg, attention_qkv_linear=atten_qkv_linear, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, ) return cfg @@ -625,6 +644,8 @@ def trainer_configs( for version, model_size, flash_attention in itertools.product( Version, MODEL_SIZES, [True, False] ): + if model_size not in TOTAL_TOKENS[version]: # This combination does not exist. + continue vocab_size = VOCAB_SIZE[version] config_name = make_config_name( arch=arch, @@ -635,8 +656,6 @@ def trainer_configs( kwargs = get_trainer_kwargs( model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention ) - if len(kwargs) == 0: # This combination does not exist - continue max_sequence_length = kwargs.pop("max_sequence_length") # pylint: disable-next=unexpected-keyword-arg,missing-kwoa config_map[config_name] = get_trainer_config_fn( @@ -690,9 +709,13 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config: # The original config was supposed to run on >= 32 machines. 
# pylint: disable=cell-var-from-loop - cfg.input.batcher.global_batch_size //= 128 if version == Version.V3 else 32 + cfg.input.batcher.global_batch_size //= ( + 128 if version in (Version.V3, Version.V3_TIKTOKEN) else 32 + ) for evaler in cfg.evalers.values(): - evaler.input.batcher.global_batch_size //= 128 if version == Version.V3 else 32 + evaler.input.batcher.global_batch_size //= ( + 128 if version in (Version.V3, Version.V3_TIKTOKEN) else 32 + ) # pylint: enable=cell-var-from-loop return cfg diff --git a/axlearn/experiments/text/gpt/param_converter_test.py b/axlearn/experiments/text/gpt/param_converter_test.py index ce4de4bf6..3d8bd2c2d 100644 --- a/axlearn/experiments/text/gpt/param_converter_test.py +++ b/axlearn/experiments/text/gpt/param_converter_test.py @@ -9,7 +9,7 @@ import pytest import torch from absl.testing import absltest, parameterized -from transformers import AutoConfig +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaForCausalLM from axlearn.common import utils @@ -27,6 +27,69 @@ # Use cpu for the test. jax.config.update("jax_platform_name", "cpu") +# Parameters are based on https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json +config_dict_1b = { + "vocab_size": 128256, + "hidden_size": 2048, + "intermediate_size": 8192, + "num_hidden_layers": 16, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "hidden_act": "silu", + "max_position_embeddings": 131072, + "initializer_range": 0.02, + "rms_norm_eps": 1e-5, + "use_cache": True, + "bos_token_id": 128000, + "eos_token_id": 128001, + "pretraining_tp": 1, + "tie_word_embeddings": True, + "rope_theta": 500000.0, + "rope_scaling": { + "factor": 32.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "attention_bias": False, + "attention_dropout": 0.0, + "mlp_bias": False, + "torch_dtype": "bfloat16", + "architectures": ["LlamaForCausalLM"], +} +# Parameters are based on https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json +config_dict_3b = {"hidden_size": 3072, "num_attention_heads": 24, "num_hidden_layers": 28} +# Parameters are based on https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json +config_dict_8b = { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_hidden_layers": 32, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "tie_word_embeddings": False, +} +# Parameters are based on https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json +config_dict_70b = { + "hidden_size": 8192, + "intermediate_size": 28672, + "num_attention_heads": 64, + "num_hidden_layers": 80, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "tie_word_embeddings": False, +} + def compute_fuji_grad(prng_key, fuji: Model, state: NestedTensor, input_batch: NestedTensor): """Compute gradient of fuji model with a pseudo loss.""" @@ -76,25 +139,13 @@ def compute_llama_grad(llama, torch_ids, state): class FujiConvertStateTest(TestCase): @parameterized.parameters( - dict( - fuji_model_name="fuji-1B-v3", - llama_model_name="Llama-3.2-1B", - ), - dict( - fuji_model_name="fuji-3B-v3", - llama_model_name="Llama-3.2-3B", - ), - dict( - fuji_model_name="fuji-8B-v3", - 
llama_model_name="Llama-3.1-8B", - ), - dict( - fuji_model_name="fuji-70B-v3", - llama_model_name="Llama-3.1-70B", - ), + dict(fuji_model_name="fuji-1B-v3-tiktoken"), + dict(fuji_model_name="fuji-3B-v3-tiktoken"), + dict(fuji_model_name="fuji-8B-v3-tiktoken"), + dict(fuji_model_name="fuji-70B-v3-tiktoken"), ) @pytest.mark.high_cpu - def test_weight_loading(self, fuji_model_name, llama_model_name): + def test_weight_loading(self, fuji_model_name): trainer_config_map = c4_trainer.named_trainer_configs() trainer_config_fn = trainer_config_map[fuji_model_name] trainer_config = trainer_config_fn() @@ -103,17 +154,14 @@ def test_weight_loading(self, fuji_model_name, llama_model_name): fuji: Model = model_config.instantiate(parent=None) prng_key = jax.random.PRNGKey(0) state = fuji.initialize_parameters_recursively(prng_key=prng_key) - config = AutoConfig.from_pretrained( - os.path.join( - dir_path, - "..", - "..", - "testdata", - "axlearn.experiments.text.gpt.param_converter_test", - f"{llama_model_name}.json", - ), - local_files_only=True, - ) + config_dict = config_dict_1b + if fuji_model_name == "fuji-3B-v3-tiktoken": + config_dict.update(config_dict_3b) + elif fuji_model_name == "fuji-8B-v3-tiktoken": + config_dict.update(config_dict_8b) + elif fuji_model_name == "fuji-70B-v3-tiktoken": + config_dict.update(config_dict_70b) + config = LlamaConfig(**config_dict) llama = LlamaForCausalLM._from_config(config) # pylint: disable=W0212 llama = llama.eval() ids = jax.random.randint(jax.random.PRNGKey(123), shape=(2, 2), minval=0, maxval=128256) @@ -135,13 +183,13 @@ def test_weight_loading(self, fuji_model_name, llama_model_name): llama_logits = output.logits.numpy() # The difference is caused by the SDPA attention layer. The deeper the larger the error. - if fuji_model_name == "fuji-1B-v3": + if fuji_model_name == "fuji-1B-v3-tiktoken": atol = 2e-3 - elif fuji_model_name == "fuji-3B-v3": + elif fuji_model_name == "fuji-3B-v3-tiktoken": atol = 2e-2 - elif fuji_model_name == "fuji-8B-v3": + elif fuji_model_name == "fuji-8B-v3-tiktoken": atol = 2e-1 - elif fuji_model_name == "fuji-70B-v3": + elif fuji_model_name == "fuji-70B-v3-tiktoken": atol = 2.0 else: atol = 2e-3 diff --git a/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py b/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py new file mode 100644 index 000000000..76f0ad7a3 --- /dev/null +++ b/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py @@ -0,0 +1,222 @@ +# Copyright © 2024 Apple Inc. + +"""Fuji v3 vocabulary.""" + +import os +import tempfile +from typing import Optional, Protocol, Sequence, Union + +import jax +import numpy as np +import tensorflow.compat.v2 as tf +from tokenizers import Tokenizer + +import axlearn.common.file_system as fs +from axlearn.common.utils import get_data_dir + + +class InnerTokenizer(Protocol): + """Defines a protocol of InnerTokenizer which is used in Vocabulary. + + This is a subset of the sentencepiece_processor.SentencePieceProcessor APIs that are used in + Vocabulary. + """ + + def encode_as_pieces(self, pieces: str) -> list[str]: + """Encode text input to tokens.""" + pass + + def piece_to_id(self, piece: str) -> int: + """Encode a token to id.""" + pass + + +class Vocabulary(Protocol): + """Defines a protocol of Vocabulary. + + This is a subset of the seqio.Vocabulary APIs that are used in text_to_lm_training_input and + text_to_lm_eval_input.
+ """ + + @property + def pad_id(self) -> int: + pass + + @property + def eos_id(self) -> Optional[int]: + pass + + def encode_tf(self, s: tf.Tensor) -> tf.Tensor: + """Tokenizes string Scalar to an int32 Tensor, without adding EOS.""" + pass + + def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: + """Detokenizes int32 batched Tensor.""" + pass + + def encode(self, s: str) -> list[int]: + """Tokenizes string to an int sequence, without adding EOS.""" + pass + + def _decode(self, ids: Sequence[int]) -> str: + """Detokenizes int sequence to a string, through all EOS.""" + pass + + def decode(self, ids: Sequence[int]) -> str: + """Detokenizes int32 iterable to a string, up through first EOS.""" + pass + + @property + def tokenizer(self) -> InnerTokenizer: + pass + + +class FujiInnerTokenizer: + """A wrapper for tokenizers.Tokenizer so that it follows the InnerTokenizer Protocol.""" + + def __init__(self, tokenizer): + self._tokenizer = tokenizer + + def encode_as_pieces(self, pieces: str) -> list[str]: + """Encode text input to tokens.""" + return self._tokenizer.encode(pieces, add_special_tokens=False).tokens + + def piece_to_id(self, piece: str) -> int: + """Encode a token to id.""" + return self._tokenizer.token_to_id(piece) + + +class FujiV3Vocabulary: + """A wrapper for tokenizers.Tokenizer so that it follows the Vocabulary Protocol. + + Although its name mentions fuji, it can be extended to work with any tokenizers.Tokenizer. + """ + + def __init__(self, filename: str): + data_dir = get_data_dir() + data_dir = ( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "data") + if data_dir is None or data_dir == "FAKE" + else data_dir + ) + filename = os.path.join(data_dir, "tokenizers", "hf", filename) + if filename.startswith("gs:") or filename.startswith("s3:"): + # Create a different file for each usage. + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "tokenizer.json") + fs.copy(filename, path) + self._tokenizer = Tokenizer.from_file(path) + else: + self._tokenizer = Tokenizer.from_file(filename) + self.vocab = self._tokenizer.get_vocab() + self.tokenizer = FujiInnerTokenizer(self._tokenizer) + + @property + def pad_id(self) -> int: + # Some tokenizers do not have a pad_id. + # https://discuss.huggingface.co/t/how-to-set-the-pad-token-for-meta-llama-llama-3-models/103418 + for token in ("<|pad_id|>", "<|finetune_right_pad_id|>"): + if token in self.vocab: + return self.vocab[token] + raise ValueError("Unable to infer pad token.") + + @property + def eos_id(self) -> Optional[int]: + if "<|end_of_text|>" in self.vocab: + return self.vocab["<|end_of_text|>"] + raise ValueError("Unable to infer eos token.") + + @property + def bos_id(self) -> Optional[int]: + if "<|begin_of_text|>" in self.vocab: + return self.vocab["<|begin_of_text|>"] + raise ValueError("Unable to infer bos token.") + + def _encode_tf(self, s: tf.Tensor) -> tf.Tensor: + """Encodes a string to token IDs. + + Args: + s: A tf.Tensor of shape () or (n,) and dtype tf.string. + + Returns: + A tf.Tensor or RaggedTensor of shape (num_tokens,) or (n, None) and dtype tf.int32. + """ + need_unpack = False + if s.ndim == 0: + s = tf.reshape(s, (1,)) + need_unpack = True + + def helper_en(s): + res = [] + for item in s.numpy(): + item = item.decode("utf-8") + encoded = self._tokenizer.encode(item, add_special_tokens=True) + ids = encoded.ids + # The return does not include EOS, but we need to remove BOS.
+ if len(ids) > 0 and ids[0] == self.bos_id: + ids = ids[1:] + res.append(ids) + return tf.ragged.constant(res, dtype=tf.int32) + + ret = tf.py_function( + helper_en, inp=[s], Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32) + ) + if need_unpack: + return ret[0] + else: + return ret + + def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: + """Detokenizes int32 batched Tensor.""" + need_unpack = False + if len(ids.shape) == 1: + ids = tf.reshape(ids, (1, -1)) + need_unpack = True + + def helper(ids): + ids = [ids[i].numpy() for i in range(ids.shape[0])] + ids = [ + item[(item != self.bos_id) & (item != self.eos_id) & (item != self.pad_id)] + for item in ids + ] + s = self._tokenizer.decode_batch(ids, skip_special_tokens=False) + return tf.convert_to_tensor(s, dtype=tf.string) + + ret = tf.py_function(helper, inp=[ids], Tout=tf.string) + ret.set_shape(tf.TensorShape((None,))) + if need_unpack: + return ret[0] + else: + return ret + + def encode_tf(self, s: tf.Tensor) -> tf.Tensor: + """Tokenizes string Scalar to an int32 Tensor, without adding EOS. + + Args: + s: A tf.Tensor of shape () or (n,) and dtype tf.string. + + Returns: + A tf.Tensor or RaggedTensor of shape (num_tokens,) or (n, None) and dtype tf.int32. + """ + return self._encode_tf(s) + + def encode(self, s: str) -> list[int]: + """Tokenizes string to an int sequence, without adding EOS.""" + ret = self._tokenizer.encode(s, add_special_tokens=True).ids + # The return does not include EOS, but we need to remove BOS. + return ret[1:] if ret[0] == self.bos_id else ret + + def _decode(self, ids: Union[list[int], tuple[int]]) -> str: + """Detokenizes int32 iterable to a string.""" + # remove BOS, EOS and PAD. + ids = np.array(ids) + ids = ids[(ids != self.bos_id) & (ids != self.eos_id) & (ids != self.pad_id)] + return self._tokenizer.decode(ids, skip_special_tokens=False) + + def decode(self, ids: Union[list[int], tuple[int], jax.Array, np.ndarray]) -> str: + """Detokenizes int32 iterable to a string, up through first EOS.""" + if self.eos_id is not None and self.eos_id in ids: + if isinstance(ids, (jax.Array, np.ndarray)): + ids = ids.tolist() # type: ignore + ids = ids[: ids.index(self.eos_id) + 1] + return self._decode(ids) diff --git a/axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py b/axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py new file mode 100644 index 000000000..dcbac41f7 --- /dev/null +++ b/axlearn/experiments/text/gpt/vocabulary_fuji_v3_test.py @@ -0,0 +1,201 @@ +# Copyright © 2024 Apple Inc. 
+ +"""Tests fuji v3 vocabulary.""" + +import numpy as np +import pytest +import tensorflow.compat.v2 as tf +from absl.testing import parameterized + +from axlearn.common import input_text, input_tf_data +from axlearn.common.config import config_for_class, config_for_function +from axlearn.common.input_lm import ( + PackingMethodType, + lm_text_preprocessor, + text_to_lm_eval_input, + text_to_lm_training_input, +) +from axlearn.common.input_text_test import make_ds_fn +from axlearn.common.test_utils import TestCase +from axlearn.experiments.text.gpt.vocabulary_fuji_v3 import FujiV3Vocabulary + + +@pytest.mark.skip(reason="no tokenizer file.") +class FujiV3VocabularyTest(TestCase): + """Tests FujiV3VocabularyTest.""" + + @property + def vocab_cfg(self): + return config_for_class(FujiV3Vocabulary).set(filename="Llama-3-tokenizer.json") + + def test_encode_tf_and_decode_tf(self): + vocab = self.vocab_cfg.instantiate() + text = tf.constant( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit\n", dtype=tf.string + ) + ids = vocab.encode_tf(text) + recovered = vocab._decode_tf(ids) # pylint: disable=W0212 + + self.assertEqual(text.numpy().decode("utf-8"), recovered.numpy().decode("utf-8")) + + def test_tokenize_example(self): + vocab = self.vocab_cfg.instantiate() + newlines_replaced_with = "" + newlines_replaced_with_id = vocab.encode(newlines_replaced_with) + + # Test tokenize_example replaces newlines. + tokens = input_text.tokenize_example( + "Hello\n", sp_vocab=vocab, replace_newlines_with=newlines_replaced_with + ).numpy() + self.assertNestedAllClose( + np.array(vocab.encode("Hello") + newlines_replaced_with_id), tokens + ) + + def test_num_bytes(self): + vocab = self.vocab_cfg.instantiate() + newlines_replaced_with = "\n" + pad_id = vocab.pad_id + newline_id = vocab.encode("\n").pop() + newlines_replaced_with_id = vocab.encode(newlines_replaced_with).pop() + + # Test num_bytes computes expected value. + ids = tf.constant( + [vocab.eos_id, newlines_replaced_with_id, newline_id, pad_id, pad_id, pad_id], + dtype=tf.int32, + ) + self.assertEqual( + 3, + input_text.num_bytes( + ids, sp_vocab=vocab, newlines_replaced_with=newlines_replaced_with + ), + ) + + @parameterized.parameters( + dict( + packing_method=PackingMethodType.EOS_DELIM_MASK, + max_padding_fraction=1.0, # Always pad + ), + dict( + packing_method=PackingMethodType.EOS_DELIM_NO_MASK, + max_padding_fraction=1.0, # Always pad + ), + dict( + packing_method=PackingMethodType.EOS_DELIM_MASK, + max_padding_fraction=0.0, # Do not pad + ), + ) + def test_fake_text_lm_training_data( + self, packing_method: PackingMethodType, max_padding_fraction: float + ): + texts = [ + "hello world\n", + "hello moon\n", + ] + + # window_size > len(texts) to repeat the sentence. 18 tokens in total. + window_size = 3 + + # Pad the concatenated sequence to 20 tokens: + # Or, trim the sequence to 15 tokens: + batch_size, max_len = 2, 5 + + # Disable shuffling to make results interpretable. + shuffle_buffer_size = 0 + + # Test text_to_lm_training_input. 
+ cfg = input_tf_data.Input.default_config().set( + name="test_input", + is_training=True, + source=config_for_function(make_ds_fn).set(texts=texts), + processor=config_for_function(text_to_lm_training_input).set( + vocab_cfg=self.vocab_cfg, + max_len=max_len, + replace_newlines_with="\n", + window_size=window_size, + max_padding_fraction=max_padding_fraction, + shuffle_buffer_size=shuffle_buffer_size, + packing_method=packing_method, + ), + batcher=config_for_function(input_tf_data.batch).set( + global_batch_size=batch_size, + prefetch_buffer_size=2, + pad_example_fn=input_tf_data.default_pad_example_fn, + ), + ) + + # Set TensorFlow seed. + tf.random.set_seed(123) + dataset = cfg.instantiate(parent=None) + for ix, batch in enumerate(dataset): + self.assertIsNotNone(batch) + batch = {k: v.tolist() for k, v in batch.items()} + if ix >= 10: + # Expect to be able to repeat forever. + break + + # Test lm_text_preprocessor. Expect same results. + cfg = input_tf_data.Input.default_config().set( + name="test_input", + is_training=True, + source=config_for_function(make_ds_fn).set( + texts=texts, + ), + processor=config_for_function(lm_text_preprocessor).set( + vocab_cfg=self.vocab_cfg, + max_sequence_length=max_len, + replace_newlines_with="", + window_size=window_size, + max_padding_fraction=max_padding_fraction, + shuffle_buffer_size=shuffle_buffer_size, + packing_method=packing_method, + ), + batcher=config_for_function(input_tf_data.batch).set( + global_batch_size=batch_size, + prefetch_buffer_size=2, + pad_example_fn=input_tf_data.default_pad_example_fn, + ), + ) + + # Reset TensorFlow seed. + tf.random.set_seed(123) + dataset = cfg.instantiate(parent=None) + for ix, batch in enumerate(dataset): + if ix >= 3: + break + batch = {k: v.tolist() for k, v in batch.items()} + + @parameterized.parameters( + ("How long is a piece of string?", "index"), + ("On the 20th of June", "not_index"), + ("Here we stand united", None), + ) + def test_eval_lm_processor_single_example(self, text, index_key): + max_len = 12 + processor = text_to_lm_eval_input( + vocab_cfg=self.vocab_cfg, + max_len=max_len, + replace_newlines_with="\n", + stride=None, + index_key="index", + ) + ds_fn = ( + config_for_function(make_ds_fn) + .set(is_training=False, texts=[text], repeat=1) + .instantiate() + ) + example = next(iter(processor(ds_fn()))) + for key in ["input_ids", "target_labels"]: + # Shape is as expected. + self.assertEqual((max_len,), example[key].numpy().shape) + self.assertTrue("target_num_bytes" in example) + # Index should have been passed through only for set value of `index_key`. + self.assertEqual(index_key == "index", index_key in example) + + input_ids, target_labels = example["input_ids"].numpy(), example["target_labels"].numpy() + self.assertEqual(128001, input_ids[0]) # EOS + non_padded_length = (target_labels == 128004).argmax() + self.assertEqual(128001, target_labels[non_padded_length - 1]) # EOS. + # The inputs should be one-off the labels. + self.assertNestedAllClose( + target_labels[: non_padded_length - 1], input_ids[1:non_padded_length] + )