@@ -73,15 +73,15 @@
 from axlearn.experiments.text.gpt.fuji import offload_attention_proj_policy
 from axlearn.experiments.trainer_config_utils import TrainerConfigFn
 
-MODEL_SIZES = ("test", "Switch-Base", "Switch-Large", "Switch-XXL", "Mistral-8x7B", "Mistral-8x7B-toy", "Mistral-8x20B")
+MODEL_SIZES = ("test", "Switch-Base", "Switch-Large", "Switch-XXL", "Mistral-8x7B", "Mistral-toy", "Mistral-8x20B")
 
 NUM_EXPERTS = {
     "test": 8,
     "Switch-Base": 128,
     "Switch-Large": 128,
     "Switch-XXL": 64,
     "Mistral-8x7B": 8,
-    "Mistral-8x7B-toy": 8,
+    "Mistral-toy": 8,
     "Mistral-8x20B": 8,
 }
 
@@ -93,7 +93,7 @@
     "Switch-Base": 8192,
     "Switch-Large": 8192,
     "Switch-XXL": 8192,
-    "Mistral-8x7B-toy": 256,
+    "Mistral-toy": 256,
     "Mistral-8x7B": 4096,
     "Mistral-8x20B": 4096,
 }
@@ -445,14 +445,14 @@ def get_trainer_kwargs(
             # TODO(kelvin-zou): not verified with real job.
             mesh_shape=mesh_shape_from_axes(fsdp=-1, expert=16, model=8),
         )
-    elif model_size in ["Mistral-8x7B", "Mistral-8x7B-toy"]:
+    elif model_size in ["Mistral-8x7B", "Mistral-toy"]:
         # Num of parameters: 47B.
         ffn_layer_types = get_ffn_layer_types()
         neuron_mesh = mesh_shape_from_axes(fsdp=-1, model=4)
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=int(os.getenv("NUM_LAYERS", 4)),
-                hidden_dim=32 * 32 if model_size == "Mistral-8x7B-toy" else 32 * 128,
+                hidden_dim=32 * 32 if model_size == "Mistral-toy" else 32 * 128,
                 ffn_dim=scaled_hidden_dim(scale=3.5, round_up_to_multiples_of=128),
                 num_heads=32,
                 num_kv_heads=8,