Commit e42b6f7

Add llama 3 tokenizer

1 parent 420ed7a

107 files changed: +7445, -266 lines

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host.txt

Lines changed: 3 additions & 3 deletions
@@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.
 model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.decoder.emb.token_emb.param_partition_spec[0]: None
 model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
+model.decoder.eos_token_id: 128001
 model.decoder.klass: 'axlearn.common.decoder.Decoder'
 model.decoder.logits_partition_spec[0][0]: 'data'
 model.decoder.logits_partition_spec[0][1]: 'expert'
@@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.output_norm.eps: 1e-05
 model.decoder.output_norm.forward_dtype: None
 model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
+model.decoder.pad_token_id: 128004
 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
 model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
 model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
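
Note on the new values (a sketch for orientation, not part of the commit): the Llama 3 tokenizer has 128000 BPE tokens plus 256 special tokens, 128256 ids in total; 128001 is <|end_of_text|>, 128004 is the id this config adopts for padding, and vocab_size is rounded up to 131072 (2**17), presumably for sharding efficiency. The rounding helper and the multiple below are illustrative assumptions.

# Minimal Python sketch relating the updated config values. The constants follow
# Meta's published Llama 3 tokenizer layout; the rounding rule is an assumption.
LLAMA3_BASE_VOCAB = 128_000        # BPE merge table size
LLAMA3_NUM_SPECIAL = 256           # special/reserved tokens appended after the base vocab
LLAMA3_TOKENIZER_VOCAB = LLAMA3_BASE_VOCAB + LLAMA3_NUM_SPECIAL  # 128_256

EOS_TOKEN_ID = 128_001             # <|end_of_text|>
PAD_TOKEN_ID = 128_004             # pad id used by this config

def round_up(n: int, multiple: int) -> int:
    """Round n up to the next multiple (hypothetical helper, for illustration only)."""
    return ((n + multiple - 1) // multiple) * multiple

# 128_256 rounded up to a multiple of 2**15 yields the padded model vocab size.
assert round_up(LLAMA3_TOKENIZER_VOCAB, 32_768) == 131_072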

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-single-host_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
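
Only the vocab axis of the embedding changes here: the hidden size stays 2048 while the row count grows to the padded 131072. A rough JAX sketch of the resulting table at lookup time (the scale convention and the example ids are illustrative assumptions, not axlearn's actual initializer):

import jax
import jax.numpy as jnp

# Shapes follow the updated testdata line: [131_072, 2_048].
vocab_size, hidden_dim = 131_072, 2_048
key = jax.random.PRNGKey(0)

# Fan-out-scaled normal init, matching "normal(0, 1.0 / fan_out)" up to the exact
# scale convention (std vs. variance), which is an assumption in this sketch.
token_emb = jax.random.normal(key, (vocab_size, hidden_dim)) / jnp.sqrt(hidden_dim)

# Ids above 128_255 are never emitted by the Llama 3 tokenizer; the extra rows
# only round the vocab up, so real lookups stay below 128_256.
ids = jnp.array([128_000, 128_001])         # <|begin_of_text|>, <|end_of_text|>
vectors = jnp.take(token_emb, ids, axis=0)  # shape (2, 2048)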

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt

Lines changed: 287 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1
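
The new *-tiktoken testdata above keeps the unpadded 128256-row embedding, which matches a tokenizer built directly on the Llama 3 vocabulary. For orientation, this is roughly how such a tokenizer is assembled with tiktoken, loosely following Meta's reference code; the file path, the regex, and the abbreviated special-token list are placeholders, not what this commit ships:

import tiktoken
from tiktoken.load import load_tiktoken_bpe

# The BPE ranks file path is a placeholder.
mergeable_ranks = load_tiktoken_bpe("/path/to/llama3/tokenizer.model")
num_base = len(mergeable_ranks)              # 128_000 for Llama 3

# Special tokens are appended after the base vocab; only two are listed here.
special_tokens = {
    "<|begin_of_text|>": num_base + 0,       # 128_000
    "<|end_of_text|>": num_base + 1,         # 128_001
}

enc = tiktoken.Encoding(
    name="llama3_sketch",
    pat_str=r"...",  # stand-in for Llama 3's pre-tokenization regex
    mergeable_ranks=mergeable_ranks,
    special_tokens=special_tokens,
)
print(enc.n_vocab)  # 128_002 with this truncated list; 128_256 with all 256 special tokens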

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt

Lines changed: 287 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+====================weight_decay_scale root.optimizer====================
+decoder/emb/token_emb/weight: 1
+decoder/output_norm/scale: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
+decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
+decoder/transformer/repeat/layer/self_attention/norm/scale: 1

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash.txt

Lines changed: 3 additions & 3 deletions
@@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.
 model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.decoder.emb.token_emb.param_partition_spec[0]: None
 model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
+model.decoder.eos_token_id: 128001
 model.decoder.klass: 'axlearn.common.decoder.Decoder'
 model.decoder.logits_partition_spec[0][0]: 'data'
 model.decoder.logits_partition_spec[0][1]: 'expert'
@@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.output_norm.eps: 1e-05
 model.decoder.output_norm.forward_dtype: None
 model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
+model.decoder.pad_token_id: 128004
 model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
 model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
 model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'

axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash_init.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
 decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
 decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
 decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
