Commit 8dfef28: Add llama 3 tokenizer
Parent: 420ed7a
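
For context, the Llama 3 tokenizer is a tiktoken-based BPE tokenizer with 128,000 mergeable ranks plus 256 special tokens (128256 ids in total). The sketch below shows roughly how such a tokenizer is assembled with the tiktoken library; the file path, split regex, and special-token list are placeholders and are not taken from this commit.

import tiktoken
from tiktoken.load import load_tiktoken_bpe

# Placeholder path to a Llama-3-style BPE ranks file (not part of this diff).
ranks = load_tiktoken_bpe("tokenizer.model")
num_base = len(ranks)  # 128,000 base tokens for Llama 3

# Hypothetical subset of the special tokens appended after the base vocabulary.
specials = ["<|begin_of_text|>", "<|end_of_text|>"]
enc = tiktoken.Encoding(
    name="llama3_sketch",
    pat_str=r"\S+|\s+",  # placeholder split pattern; Llama 3 ships its own regex
    mergeable_ranks=ranks,
    special_tokens={tok: num_base + i for i, tok in enumerate(specials)},
)
print(enc.n_vocab)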

File tree: 104 files changed (+7411, -232 lines)


axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
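
The change from 128256 to 131072 reads as padding the vocab/embedding dimension beyond the tokenizer's 128256 ids up to the next power of two (131072 = 2^17); the diff itself does not state the rationale, so treat the rounding rule below as an assumption.

# Assumed padding rule: round the tokenizer size up to the next power of two.
tokenizer_vocab = 128_256                         # 128,000 BPE ranks + 256 special tokens
padded = 1 << (tokenizer_vocab - 1).bit_length()
print(padded)                                     # 131072 == 2**17
print(padded - tokenizer_vocab)                   # 2816 unused embedding rows under this reading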

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
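
The FanAxes annotations above pick out which axes of each weight count toward fan_in and fan_out for the fan-scaled normal initializers. A minimal sketch of that bookkeeping for the shapes listed here; the helper is illustrative and is not axlearn's implementation:

import math

def fan(shape, axes):
    # `axes` is a single axis index or a tuple of them; the fan is the product of those dims.
    axes = (axes,) if isinstance(axes, int) else tuple(axes)
    return math.prod(shape[a] for a in axes)

print(fan((131072, 2048), -1))      # token_emb fan_out = 2048 (out_axis=-1)
print(fan((2048, 48, 64), 0))       # qkv_proj fan_in = 2048 (in_axis=0)
print(fan((2048, 48, 64), (1, 2)))  # qkv_proj fan_out = 48 * 64 = 3072 (out_axis=(1, 2))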

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt

Lines changed: 287 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
@@ -0,0 +1,10 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
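
Every parameter above is listed with weight_decay_scale 1, i.e. none of them are exempted from the base decay. As a concept sketch (not axlearn's optimizer code), a per-parameter scale usually just multiplies the decoupled weight-decay term in the update:

# Illustrative decoupled (AdamW-style) weight decay with a per-parameter scale.
def decayed_param(param, step, lr, weight_decay, scale):
    # scale=0 would exempt a parameter (e.g. norms or biases in some recipes);
    # in this test data every parameter uses scale=1.
    return param + step - lr * weight_decay * scale * param

print(decayed_param(param=0.5, step=-0.001, lr=1e-3, weight_decay=0.1, scale=1))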

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt

Lines changed: 287 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
