-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
32 b #121
base: main
Are you sure you want to change the base?
32 b #121
Changes from 17 commits
b94e702
368abb8
c277d54
7c74d8b
53d61fe
514abb8
011113e
2577397
93637a1
784377d
eec7e10
bd5edee
f516f09
49264f5
4bb5d5c
1ff1371
4375612
20b9b08
7736198
8e0613f
5652953
323c786
4f676e2
d4e63fa
7c22386
f38bff4
3bf2440
ab5afcf
47f9545
ee6aa90
c656a41
7852e1e
b19e76d
a02dd95
6eaa5a3
b2a07de
9985d31
4cc6a62
1060499
fb2a274
1073613
c553b98
e49d4b7
9608482
4804004
fd4edb8
1f79446
072c616
6ba3e23
07cc66c
18e9a32
2150b36
c8cf403
d9cb6cf
5f2cf19
19c8758
9a12202
d5e6e2b
ea0acce
d2a00a7
016e426
a28ca37
1c33794
484d01c
275364c
54d5623
4644e6e
d7ed30e
0c47992
246eff6
b956e3f
f877907
58bef95
c84708f
56c4ab3
b5f3a86
3fbdeb0
b335cdf
30f8f59
e17e4b8
ba49cc4
25ede33
ac01e83
ddd61ac
973a26c
baf5700
b6762d8
d98f06d
4a68e9e
d81cd12
0a04034
5acc7eb
213b03e
b4994b0
3b84351
178d9ad
0b737aa
3e6f9f1
663d63a
496919b
a1854bd
f2de5f4
33c0f58
86afc43
2e45a79
e4e8fbb
146caaf
393a462
d39c59d
16983c4
5e4d04f
eba0418
5605001
7ce7efa
05aa94f
9c86bf9
e27b91d
52b9b77
f045eee
ddb3084
72e0ed1
d1d8dcb
a0700e8
0595cf8
74c6960
df46d5c
985785c
6cc9e99
6c31495
f47f6f5
db0df12
7f98496
be4e788
bfa6a8d
b1ad693
fde5f68
4264050
a7b4507
5fbc50e
7f6a6d0
f72cd46
a3d6672
7c573b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": | |
) | ||
|
||
@classmethod | ||
def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": | ||
def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": | ||
""" | ||
A 26B OLMo model config. | ||
A 32B OLMo model config. | ||
""" | ||
d_model = 5120 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a very narrow model then... are you sure about that? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's a clone of Qwen 32. The tradeoffs are: narrow d_model, wide FFN, GQA, and lots of layers. |
||
return cls.llama_like( | ||
vocab_size=vocab_size, | ||
d_model=7168, | ||
n_layers=kwargs.pop("n_layers", 40), | ||
n_heads=kwargs.pop("n_heads", 56), | ||
d_model=d_model, | ||
n_layers=kwargs.pop("n_layers", 64), | ||
n_heads=kwargs.pop("n_heads", 40), | ||
n_kv_heads=kwargs.pop("n_kv_heads", 8), | ||
block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm), | ||
qk_norm=kwargs.pop("qk_norm", True), | ||
rope_theta=kwargs.pop("rope_theta", 500_000), | ||
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024), | ||
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512), | ||
hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)), | ||
layer_norm_eps=1e-6, | ||
**kwargs, | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,7 +130,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]: | |
eval_batch_size = ( | ||
self.eval_batch_size | ||
if self.eval_batch_size is not None | ||
else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) | ||
else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: you could instead pass an updated evaluator callback in
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, but I think this is better. I think we can default to 2x the training batch size. It should always work. |
||
) | ||
dataset = self.eval_dataset.build() | ||
if not isinstance(dataset, NumpyPaddedFSLDataset): | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -15,15 +15,16 @@ | |||
TransformerDataParallelConfig, | ||||
) | ||||
from olmo_core.optim import AdamWConfig, OptimGroupOverride | ||||
from olmo_core.train import TrainerConfig | ||||
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback | ||||
from olmo_core.train import TrainerConfig, Duration, DurationUnit | ||||
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback, \ | ||||
DownstreamEvaluatorCallbackConfig | ||||
|
||||
log = logging.getLogger(__name__) | ||||
|
||||
|
||||
def build_model_config(common: CommonComponents) -> TransformerConfig: | ||||
compile = True | ||||
return TransformerConfig.olmo2_26B( | ||||
return TransformerConfig.olmo2_32B( | ||||
vocab_size=common.tokenizer.padded_vocab_size(), | ||||
compile=compile, | ||||
fused_ops=False, | ||||
|
@@ -52,20 +53,23 @@ def build_optim_config(common: CommonComponents) -> AdamWConfig: | |||
|
||||
|
||||
def build_trainer_config(common: CommonComponents) -> TrainerConfig: | ||||
project_name = "peteish32" | ||||
return ( | ||||
TrainerConfig( | ||||
save_folder=common.save_folder, | ||||
save_folder=f"gs://ai2-llm/checkpoints/{project_name}/", | ||||
dirkgr marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
rank_microbatch_size=4 * 4096, | ||||
save_overwrite=True, | ||||
metrics_collect_interval=10, | ||||
cancel_check_interval=1, | ||||
cancel_check_interval=10, | ||||
z_loss_multiplier=1e-5, | ||||
compile_loss=True, | ||||
fused_loss=True, | ||||
compile_loss=False, | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I understand the trepidation about the different loss implementations, but the way it was before was the most performant. This way will be slower and have a higher memory footprint. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we have some certainty that this will do the right thing? What happens if we take the 13B from a late checkpoint and run it? |
||||
max_duration=Duration(int(6.5e12), DurationUnit.tokens) | ||||
) | ||||
.with_callback( | ||||
"checkpointer", | ||||
CheckpointerCallback( | ||||
save_interval=10_000, | ||||
save_interval=1000, | ||||
ephemeral_save_interval=250, | ||||
save_async=True, | ||||
), | ||||
|
@@ -75,7 +79,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: | |||
CometCallback( | ||||
name=common.run_name, | ||||
workspace="ai2", | ||||
project="OLMo-core-26B", | ||||
project=project_name, | ||||
enabled=True, | ||||
cancel_check_interval=10, | ||||
), | ||||
|
@@ -85,10 +89,57 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: | |||
WandBCallback( | ||||
name=common.run_name, | ||||
entity="ai2-llm", | ||||
project="OLMo-core-26B", | ||||
project=project_name, | ||||
enabled=False, | ||||
dirkgr marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
cancel_check_interval=10, | ||||
), | ||||
).with_callback( | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should just add this to the common callbacks.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know that we want these for everything. Default should probably be only the new, blessed ones. |
||||
"downstream_evaluator", | ||||
DownstreamEvaluatorCallbackConfig( | ||||
tasks=[ | ||||
# MMLU for backwards compatibility | ||||
"mmlu_stem_mc_5shot", | ||||
"mmlu_humanities_mc_5shot", | ||||
"mmlu_social_sciences_mc_5shot", | ||||
"mmlu_other_mc_5shot", | ||||
|
||||
# MMLU test | ||||
"mmlu_stem_mc_5shot_test", | ||||
"mmlu_humanities_mc_5shot_test", | ||||
"mmlu_social_sciences_mc_5shot_test", | ||||
"mmlu_other_mc_5shot_test", | ||||
|
||||
# Core 12 tasks for backwards compatibility | ||||
"arc_challenge", | ||||
"arc_easy", | ||||
"basic_arithmetic", | ||||
"boolq", | ||||
"commonsense_qa", | ||||
"copa", | ||||
"hellaswag", | ||||
"openbook_qa", | ||||
"piqa", | ||||
"sciq", | ||||
"social_iqa", | ||||
"winogrande", | ||||
|
||||
# Core 12 tasks 5-shot | ||||
"arc_challenge_rc_5shot", | ||||
"arc_easy_rc_5shot", | ||||
#"basic_arithmetic_rc_5shot", # doesn't exist | ||||
#"boolq_rc_5shot", # we don't like it | ||||
"csqa_rc_5shot", | ||||
#"copa_rc_5shot", # doesn't exist | ||||
"hellaswag_rc_5shot", | ||||
"openbookqa_rc_5shot", | ||||
"piqa_rc_5shot", | ||||
#"sciq_rc_5shot", # doesn't exist | ||||
"socialiqa_rc_5shot", | ||||
"winogrande_rc_5shot" | ||||
], | ||||
tokenizer=common.tokenizer, | ||||
eval_interval=1000, | ||||
), | ||||
) | ||||
) | ||||
|
||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our in-house triton CE loss was copied directly from the flash-attn repo, so I don't see the point of this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I took this back out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I want compiling and fused loss at the same time?