Skip to content

Commit 17aced8

Browse files
committed
Make smallest model named toy
1 parent 7accada commit 17aced8

File tree

1 file changed

+5
-5
lines changed
  • axlearn/experiments/text/gpt

1 file changed

+5
-5
lines changed

axlearn/experiments/text/gpt/envy.py

+5 -5
Original file line number | Diff line number | Diff line change
@@ -73,15 +73,15 @@
7373
from axlearn.experiments.text.gpt.fuji import offload_attention_proj_policy
7474
from axlearn.experiments.trainer_config_utils import TrainerConfigFn
7575

76-
MODEL_SIZES = ("test", "Switch-Base", "Switch-Large", "Switch-XXL", "Mistral-8x7B", "Mistral-8x7B-toy", "Mistral-8x20B")
76+
MODEL_SIZES = ("test", "Switch-Base", "Switch-Large", "Switch-XXL", "Mistral-8x7B", "Mistral-toy", "Mistral-8x20B")
7777

7878
NUM_EXPERTS = {
7979
"test": 8,
8080
"Switch-Base": 128,
8181
"Switch-Large": 128,
8282
"Switch-XXL": 64,
8383
"Mistral-8x7B": 8,
84-
"Mistral-8x7B-toy": 8,
84+
"Mistral-toy": 8,
8585
"Mistral-8x20B": 8,
8686
}
8787

@@ -93,7 +93,7 @@
9393
"Switch-Base": 8192,
9494
"Switch-Large": 8192,
9595
"Switch-XXL": 8192,
96-
"Mistral-8x7B-toy": 256,
96+
"Mistral-toy": 256,
9797
"Mistral-8x7B": 4096,
9898
"Mistral-8x20B": 4096,
9999
}
@@ -445,14 +445,14 @@ def get_trainer_kwargs(
445445
# TODO(kelvin-zou): not verified with real job.
446446
mesh_shape=mesh_shape_from_axes(fsdp=-1, expert=16, model=8),
447447
)
448-
elif model_size in ["Mistral-8x7B", "Mistral-8x7B-toy"]:
448+
elif model_size in ["Mistral-8x7B", "Mistral-toy"]:
449449
# Num of parameters: 47B.
450450
ffn_layer_types = get_ffn_layer_types()
451451
neuron_mesh = mesh_shape_from_axes(fsdp=-1, model=4)
452452
trainer_kwargs = dict(
453453
model_kwargs=dict(
454454
num_layers=int(os.getenv("NUM_LAYERS", 4)),
455-
hidden_dim=32 * 32 if model_size == "Mistral-8x7B-toy" else 32 * 128,
455+
hidden_dim=32 * 32 if model_size == "Mistral-toy" else 32 * 128,
456456
ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=128),
457457
num_heads=32,
458458
num_kv_heads=8,

0 commit comments

Comments
 (0)