
32 b #121

Draft: wants to merge 143 commits into base: main

Commits (changes shown are from 122 of 143 commits)
b94e702
Save more often
dirkgr Dec 8, 2024
368abb8
Don't check for cancelation all the time
dirkgr Dec 8, 2024
c277d54
Make sure we use the same CE loss that we used for the 13B
dirkgr Dec 8, 2024
7c74d8b
We're going to 5T!
dirkgr Dec 8, 2024
53d61fe
We can live with a bigger eval batch size.
dirkgr Dec 8, 2024
514abb8
Add MMLU downstream eval
dirkgr Dec 9, 2024
011113e
Module isn't callable
dirkgr Dec 9, 2024
2577397
Qwen-ish
dirkgr Dec 9, 2024
93637a1
Make model bigger
dirkgr Dec 9, 2024
784377d
It's now a 32B.
dirkgr Dec 10, 2024
eec7e10
6T tokens
dirkgr Dec 10, 2024
bd5edee
Official save folder
dirkgr Dec 10, 2024
f516f09
6.5T tokens
dirkgr Dec 10, 2024
49264f5
Merge remote-tracking branch 'origin/main' into 32B
dirkgr Dec 10, 2024
4bb5d5c
Merged
dirkgr Dec 10, 2024
1ff1371
Change project name and location
dirkgr Dec 10, 2024
4375612
Revert "Merged"
dirkgr Dec 10, 2024
20b9b08
Revert "Module isn't callable"
dirkgr Dec 10, 2024
7736198
Revert "Make sure we use the same CE loss that we used for the 13B"
dirkgr Dec 10, 2024
8e0613f
We still want it fused!
dirkgr Dec 10, 2024
5652953
One-in-two activation checkpointing
dirkgr Dec 10, 2024
323c786
Merge remote-tracking branch 'origin/main' into 32B
dirkgr Dec 10, 2024
4f676e2
Smaller microbatch
dirkgr Dec 10, 2024
d4e63fa
Wrap 3 in 4 blocks
dirkgr Dec 10, 2024
7c22386
Don't compile the loss.
dirkgr Dec 10, 2024
f38bff4
Turn off broken eval
dirkgr Dec 11, 2024
3bf2440
Go back to mbsz of 4
dirkgr Dec 11, 2024
ab5afcf
Set drop_last for DownstreamEvaluator to False
2015aroras Dec 11, 2024
47f9545
Bring back Copa now that we have Shane's fix
dirkgr Dec 11, 2024
ee6aa90
Merge remote-tracking branch 'origin/32B' into 32B
dirkgr Dec 11, 2024
c656a41
Check if beaker loading issues are due to beaker changes by updating …
2015aroras Dec 11, 2024
7852e1e
Try hsdp with 2 nodes per replica
2015aroras Dec 11, 2024
b19e76d
Revert "Try hsdp with 2 nodes per replica"
2015aroras Dec 11, 2024
a02dd95
Try activation checkpointing 3 in 4
2015aroras Dec 12, 2024
6eaa5a3
Try activation checkpointing 3 in 4 + all feedforwards checkpointed
2015aroras Dec 12, 2024
b2a07de
Decrease microbatch size
2015aroras Dec 13, 2024
9985d31
Try activation checkpointing on just feed forwards
2015aroras Dec 13, 2024
4cc6a62
Fix name
dirkgr Dec 16, 2024
1060499
Try to run with hybrid sharding.
dirkgr Dec 16, 2024
fb2a274
More batch
dirkgr Dec 16, 2024
1073613
Revert "More batch"
dirkgr Dec 16, 2024
c553b98
There is something wrong with how the `common` object is set up.
dirkgr Dec 16, 2024
e49d4b7
We need a less sharded checkpoint and I guess this is the only way we…
dirkgr Dec 16, 2024
9608482
Revert "We need a less sharded checkpoint and I guess this is the onl…
dirkgr Dec 16, 2024
4804004
Async checkpointer may have problems with large checkpoints?
dirkgr Dec 16, 2024
fd4edb8
For loading checkpoints, it seems we need a longer timeout
dirkgr Dec 16, 2024
1f79446
Revert "Async checkpointer may have problems with large checkpoints?"
dirkgr Dec 16, 2024
072c616
Flight to safety
dirkgr Dec 16, 2024
6ba3e23
Increase microbatch size up to 2 * 4096
2015aroras Dec 17, 2024
07cc66c
Watching the 32B in a notebook
dirkgr Dec 18, 2024
18e9a32
Merge branch '32B' of https://github.com/allenai/OLMo-core into 32B
dirkgr Dec 18, 2024
2150b36
Merge branch 'main' into 32B
2015aroras Dec 19, 2024
c8cf403
Enable HSDP with pre-downloading
2015aroras Dec 19, 2024
d9cb6cf
Turn off hsdp
2015aroras Dec 19, 2024
5f2cf19
Revert "Turn off hsdp"
2015aroras Dec 19, 2024
19c8758
Add option to set thread_count
2015aroras Dec 19, 2024
9a12202
Run formatter
2015aroras Dec 19, 2024
d5e6e2b
Limit thread count
2015aroras Dec 19, 2024
ea0acce
Decrease microbatch size
2015aroras Dec 19, 2024
d2a00a7
Increase microbatch size, increase activation checkpointing
2015aroras Dec 19, 2024
016e426
Decrease microbatch size
2015aroras Dec 20, 2024
a28ca37
Decrease thread_count
2015aroras Dec 20, 2024
1c33794
Thread count 1
2015aroras Dec 20, 2024
484d01c
Back to FSDP
2015aroras Dec 20, 2024
275364c
Back to HSDP, but with less replicas
2015aroras Dec 20, 2024
54d5623
Merge branch 'main' into 32B
2015aroras Dec 20, 2024
4644e6e
Microbatch size back to 1
2015aroras Dec 20, 2024
d7ed30e
Revert "Microbatch size back to 1"
2015aroras Dec 20, 2024
0c47992
Back to FSDP
2015aroras Dec 20, 2024
246eff6
Revert "Back to FSDP"
2015aroras Dec 20, 2024
b956e3f
Enable NCCL debug
2015aroras Dec 20, 2024
f877907
More debug info
2015aroras Dec 20, 2024
58bef95
Merge branch 'main' into 32B
2015aroras Dec 20, 2024
c84708f
Disable pre_download, set higher thread count
2015aroras Dec 20, 2024
56c4ab3
FSDP with AC of selected ops
2015aroras Dec 20, 2024
b5f3a86
Back to AC of just feedforward layers
2015aroras Dec 21, 2024
3fbdeb0
Add new inloop evals
2015aroras Dec 21, 2024
b335cdf
Turn off NCCL debug
2015aroras Dec 21, 2024
30f8f59
Merge branch 'main' into 32B
2015aroras Dec 21, 2024
e17e4b8
Make checkpoint writing respect thread count config
2015aroras Dec 22, 2024
ba49cc4
Add skip step optimizer changes
2015aroras Dec 22, 2024
25ede33
Update 32B config with skip step adamw
2015aroras Dec 22, 2024
ac01e83
Try fix skip step optimizer
2015aroras Dec 22, 2024
ddd61ac
Try manual _std_mean impl
2015aroras Dec 22, 2024
973a26c
Add skip step fixes
2015aroras Dec 22, 2024
baf5700
Have separate save and load thread counts
2015aroras Dec 22, 2024
b6762d8
Decrease threads used for saving
2015aroras Dec 22, 2024
d98f06d
Skipped steps and automatic spike analysis
dirkgr Dec 22, 2024
4a68e9e
Use compile=True for optimizer
2015aroras Dec 22, 2024
d81cd12
Make gcs upload pass generation
2015aroras Dec 23, 2024
0a04034
Update CHANGELOG
2015aroras Dec 23, 2024
5acc7eb
Run formatter
2015aroras Dec 23, 2024
213b03e
Make generation 0 when object does not exist
2015aroras Dec 23, 2024
b4994b0
Merge branch 'shanea/fix-upload-retries' into 32B
2015aroras Dec 23, 2024
3b84351
Run formatting
2015aroras Dec 23, 2024
178d9ad
Remove unneeded import
2015aroras Dec 23, 2024
0b737aa
Add missing reload
2015aroras Dec 23, 2024
3e6f9f1
Updated notebook
dirkgr Dec 23, 2024
663d63a
Updated dashboard
dirkgr Dec 24, 2024
496919b
Update the notebook
dirkgr Dec 24, 2024
a1854bd
Updated notebook
dirkgr Dec 27, 2024
f2de5f4
Retry on bad request
dirkgr Dec 28, 2024
33c0f58
Add some more retries
dirkgr Dec 28, 2024
86afc43
Updated the notebook
dirkgr Dec 29, 2024
2e45a79
Update the dashboard
dirkgr Dec 30, 2024
e4e8fbb
Fix the way we use the step in the optimizer
dirkgr Dec 31, 2024
146caaf
Dashboard update
dirkgr Dec 31, 2024
393a462
Update dashboard
dirkgr Jan 3, 2025
d39c59d
New report
dirkgr Jan 6, 2025
16983c4
Dashboard update
dirkgr Jan 7, 2025
5e4d04f
No more ephemeral checkpoints
dirkgr Jan 8, 2025
eba0418
Don't eval so much
dirkgr Jan 8, 2025
5605001
When you wait on someone, you bring them water.
dirkgr Jan 8, 2025
7ce7efa
Updating the dashboard
dirkgr Jan 8, 2025
05aa94f
Reorder ranks in GCP
dirkgr Jan 9, 2025
9c86bf9
Rank 0 needs to remain rank 0
dirkgr Jan 9, 2025
e27b91d
Slightly less checkpointing
dirkgr Jan 9, 2025
52b9b77
Revert "Slightly less checkpointing"
dirkgr Jan 9, 2025
f045eee
Turn off failure propagation to make slack notifier work better
2015aroras Jan 13, 2025
ddb3084
New dashboard
dirkgr Jan 14, 2025
72e0ed1
Merge branch '32B' of https://github.com/allenai/OLMo-core into 32B
dirkgr Jan 14, 2025
d1d8dcb
hopefully make GCS client calls more robust
epwalsh Jan 14, 2025
a0700e8
Catch user exceptions as well as system exceptions when training fails
2015aroras Jan 15, 2025
0595cf8
Revert "Catch user exceptions as well as system exceptions when train…
2015aroras Jan 15, 2025
74c6960
Dashboard
dirkgr Jan 16, 2025
df46d5c
Merge remote-tracking branch 'origin/32B' into 32B
dirkgr Jan 16, 2025
985785c
Suppress Google checksum warnings
2015aroras Jan 16, 2025
6cc9e99
Setup kernel cache for PyTorch
2015aroras Jan 16, 2025
6c31495
Dashboard
dirkgr Jan 17, 2025
f47f6f5
minor clean up
epwalsh Jan 17, 2025
db0df12
Add profiler
dirkgr Jan 20, 2025
7f98496
Dashboard
dirkgr Jan 20, 2025
be4e788
clean up rank reordering
epwalsh Jan 21, 2025
bfa6a8d
move script to launch module so it's available in package
epwalsh Jan 21, 2025
b1ad693
remove old
epwalsh Jan 21, 2025
fde5f68
fix merge conflicts
epwalsh Jan 21, 2025
4264050
clean up
epwalsh Jan 21, 2025
a7b4507
Merge branch 'main' into 32B
epwalsh Jan 21, 2025
5fbc50e
throttle uploads
epwalsh Jan 21, 2025
7f6a6d0
Add annealing config
epwalsh Jan 22, 2025
f72cd46
update annealing config for dolmino mix
epwalsh Jan 23, 2025
a3d6672
minor refactor
epwalsh Jan 23, 2025
7c573b2
update mix
epwalsh Jan 23, 2025
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -15,14 +15,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- Added `SkipStepAdamW` optimizer.

### Changed

- Changed storage of shared shard state in sharded checkpoints from smallest shard to lowest rank (normally 0).
- Changed underlying AdamW implementation.

### Fixed

- Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch.
- Fixed bug where GCS upload does not retry on transient failures.

## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27

15 changes: 11 additions & 4 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -63,6 +63,7 @@ def save_state_dict(
state_dict: Dict[str, Any],
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
):
"""
Save an arbitrary state dictionary to a distributed format that can loaded again with
@@ -80,7 +81,7 @@
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
)

@@ -93,6 +94,7 @@ def save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
) -> None:
"""
Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -123,7 +125,7 @@
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
planner=planner,
)
@@ -137,6 +139,7 @@ def async_save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
) -> Future[None]:
"""
An async version of :func:`save_model_and_optim_state()`.
@@ -148,7 +151,7 @@
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
planner=planner,
)
@@ -164,6 +167,7 @@ def load_model_and_optim_state(
key_mapping: Optional[Dict[str, str]] = None,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
thread_count: Optional[int] = None,
):
"""
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -201,10 +205,13 @@
This dictionary should map current keys to keys in the checkpoint to be loaded.
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
:param work_dir: A working directory for caching files/directories.
:param thread_count: Set the number of threads used for certain operations.
"""
dir = normalize_path(dir)
state_dict = _prepare_state_dict(model, optim, process_group=process_group)
reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir)
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)

if key_mapping is not None:
metadata = reader.read_metadata()
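Note on the new `thread_count` parameter threaded through the save/load helpers above: a minimal usage sketch follows. It assumes the usual `(dir, model, optim)` positional arguments hidden in the folded parts of these signatures, and the path, model, and thread counts are illustrative only, not the 32B run's settings.

import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import (
    load_model_and_optim_state,
    save_model_and_optim_state,
)

model = nn.Linear(8, 8)                        # stand-in for the real model
optim = torch.optim.AdamW(model.parameters())

# Fewer writer threads eases pressure on remote storage when saving...
save_model_and_optim_state(
    "gs://my-bucket/checkpoints/step1000",     # hypothetical location
    model,
    optim,
    save_overwrite=True,
    thread_count=2,
)

# ...while pre-downloading plus a modest thread count can speed up loading.
load_model_and_optim_state(
    "gs://my-bucket/checkpoints/step1000",
    model,
    optim,
    pre_download=True,
    thread_count=4,
)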
1 change: 1 addition & 0 deletions src/olmo_core/distributed/utils.py
@@ -92,6 +92,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")

if backend_supports_cuda(backend):
# Set CUDA device.
1 change: 1 addition & 0 deletions src/olmo_core/internal/common.py
@@ -102,6 +102,7 @@ def build_launch_config(
# Setup python environment.
"conda shell.bash activate base",
"pip install -e '.[all]'",
"pip install --upgrade beaker-py",
# Quickly try a new version of PyTorch like this
# "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121",
"pip freeze",
1 change: 1 addition & 0 deletions src/olmo_core/internal/experiment.py
@@ -130,6 +130,7 @@ def build_common_components(
root_dir=root_dir,
cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
cluster=cluster,
nccl_debug=False,
)

beaker_user = get_beaker_username()
43 changes: 31 additions & 12 deletions src/olmo_core/io.py
@@ -532,16 +532,25 @@ def _get_gcs_client():


def _gcs_is_retriable(exc: Exception) -> bool:
from google.api_core.exceptions import BadRequest
from google.api_core.retry import if_transient_error

return if_transient_error(exc) or isinstance(exc, requests.exceptions.Timeout)
return (
if_transient_error(exc)
or isinstance(exc, requests.exceptions.Timeout)
or isinstance(exc, BadRequest) # Weird choice, but Google throws this transiently
)


def _get_gcs_retry():
from google.api_core.retry import Retry

return Retry(
predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0
predicate=_gcs_is_retriable, # NOTE: it appears google might ignore this
initial=1.0,
maximum=10.0,
multiplier=2.0,
timeout=600.0,
)
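A reviewer further down notes that wrapping these calls in the outer `@retriable(...)` decorator stretches the worst case from roughly 10 minutes to roughly 30. A back-of-the-envelope sketch of that arithmetic, assuming the outer wrapper makes up to three attempts (its actual defaults are not shown in this diff):

inner_deadline_s = 600.0   # timeout of the inner google-api-core Retry above
outer_attempts = 3         # assumed attempt count for the outer @retriable()

# Each outer attempt can burn through the full inner deadline before giving up.
worst_case_minutes = outer_attempts * inner_deadline_s / 60
print(worst_case_minutes)  # 30.0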


@@ -554,7 +563,7 @@ def _get_gcs_conditional_retry():
return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"])


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_file_size(bucket_name: str, key: str) -> int:
from google.api_core.exceptions import NotFound

@@ -569,35 +578,45 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
return blob.size


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
from google.api_core.exceptions import NotFound

storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
try:
blob.reload()
blob.reload(retry=_get_gcs_retry())
except NotFound:
raise FileNotFoundError(f"gs://{bucket_name}/{key}")
return blob.download_as_bytes(
start=bytes_start, end=bytes_start + num_bytes - 1, retry=_get_gcs_retry()
)


@retriable()
@retriable(retry_condition=_gcs_is_retriable)

Review comment (Contributor): This general approach sort of blows up our retry time from 10 mins to 30 mins. Sort of not a fan.
Review comment (Contributor): But at least it looks like it works.
Review comment (Member): We could always reduce the deadline/timeout

def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
if not save_overwrite and blob.exists():
raise FileExistsError(
f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
)
blob.upload_from_filename(source, retry=_get_gcs_conditional_retry())

generation: int = 0
if blob.exists(retry=_get_gcs_retry()):
if not save_overwrite:
raise FileExistsError(
f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
)

blob.reload(retry=_get_gcs_retry())
assert blob.generation is not None
generation = blob.generation

blob.upload_from_filename(
source, if_generation_match=generation, retry=_get_gcs_conditional_retry()
)


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_clear_directory(bucket_name: str, prefix: str):
from google.api_core.exceptions import NotFound

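Context for the `_gcs_upload` change above: google-cloud-storage only retries uploads that carry a generation precondition, which is what the `is_generation_specified` predicate inside `ConditionalRetryPolicy` checks, so passing `if_generation_match` is what makes the retries take effect at all. A standalone sketch of the same pattern, with hypothetical bucket and object names:

from google.cloud import storage
from google.cloud.storage.retry import DEFAULT_RETRY_IF_GENERATION_SPECIFIED

client = storage.Client()
blob = client.bucket("my-bucket").blob("checkpoints/step1000/model.pt")  # hypothetical

# generation == 0 means "only create the object if it does not already exist";
# a real generation number pins the overwrite to the exact version just observed.
generation = 0
if blob.exists():
    blob.reload()  # populates blob.generation
    assert blob.generation is not None
    generation = blob.generation

# With if_generation_match set, the upload is idempotent, so the conditional
# retry policy will actually retry it on transient failures.
blob.upload_from_filename(
    "model.pt",
    if_generation_match=generation,
    retry=DEFAULT_RETRY_IF_GENERATION_SPECIFIED,
)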
10 changes: 9 additions & 1 deletion src/olmo_core/launch/beaker.py
@@ -323,6 +323,14 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
]

if torchrun:
entrypoint_script.append(
"export BEAKER_REPLICA_RANK=$("
"python src/scripts/reorder_ranks_in_gcp.py "
"${BEAKER_REPLICA_RANK} "
"${BEAKER_REPLICA_COUNT} "
"${BEAKER_LEADER_REPLICA_HOSTNAME}"
")"
)
entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"')
else:
entrypoint_script.append('python "$@"')
@@ -341,7 +349,7 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
leader_selection=self.num_nodes > 1,
host_networking=self.num_nodes > 1
or any(["augusta" in cluster for cluster in self.clusters]),
propagate_failure=True if self.num_nodes > 1 else None,
propagate_failure=False,
propagate_preemption=True if self.num_nodes > 1 else None,
synchronized_start_timeout="90m" if self.num_nodes > 1 else None,
resources=TaskResources(gpu_count=self.num_gpus, shared_memory="10GiB"),
15 changes: 9 additions & 6 deletions src/olmo_core/nn/transformer/config.py
@@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
)

@classmethod
def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
"""
A 26B OLMo model config.
A 32B OLMo model config.
"""
d_model = 5120

Review comment (Member): this is a very narrow model then... are you sure about that?
Review comment (Member Author): It's a clone of Qwen 32. The tradeoffs are, narrow d_model, wide FFN, GQA, lots of layers.

return cls.llama_like(
vocab_size=vocab_size,
d_model=7168,
n_layers=kwargs.pop("n_layers", 40),
n_heads=kwargs.pop("n_heads", 56),
d_model=d_model,
n_layers=kwargs.pop("n_layers", 64),
n_heads=kwargs.pop("n_heads", 40),
n_kv_heads=kwargs.pop("n_kv_heads", 8),
block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm),
qk_norm=kwargs.pop("qk_norm", True),
rope_theta=kwargs.pop("rope_theta", 500_000),
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024),
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512),
hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)),
layer_norm_eps=1e-6,
**kwargs,
)
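
To put numbers on the review exchange above (narrow `d_model`, wide FFN, GQA, many layers): a rough parameter count from the values in this hunk, assuming a SwiGLU feed-forward with three projections, untied embeddings, and a vocabulary of roughly 100k tokens; norms and biases are ignored.

d_model, n_layers, n_heads, n_kv_heads = 5120, 64, 40, 8
ffn_hidden = 27648      # hidden_size_multiplier * (8 * d_model / 3), a multiple of 512
vocab_size = 100_352    # assumption: roughly the OLMo tokenizer vocabulary

head_dim = d_model // n_heads                   # 128
attn = 2 * d_model * d_model                    # Q and output projections
attn += 2 * d_model * (n_kv_heads * head_dim)   # GQA K and V projections
ffn = 3 * d_model * ffn_hidden                  # SwiGLU: gate, up, down
per_layer = attn + ffn

embeddings = 2 * vocab_size * d_model           # untied input and output embeddings
total = n_layers * per_layer + embeddings
print(f"{total / 1e9:.1f}B parameters")         # ~32.2B, hence the rename to olmo2_32B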
5 changes: 4 additions & 1 deletion src/olmo_core/optim/__init__.py
@@ -1,5 +1,5 @@
from .adam import AdamConfig
from .adamw import AdamWConfig
from .adamw import AdamW, AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig
from .config import OptimConfig, OptimGroupOverride
from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig
from .scheduler import (
@@ -19,6 +19,9 @@
"SkipStepOptimizer",
"AdamWConfig",
"AdamConfig",
"AdamW",
"SkipStepAdamWConfig",
"SkipStepAdamW",
"LionConfig",
"Lion",
"SkipStepLionConfig",
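
A hedged usage sketch for the newly exported `SkipStepAdamW`: it assumes `SkipStepAdamWConfig` mirrors the existing `AdamWConfig` interface (standard AdamW hyperparameters plus a `build(model)` method), which this diff does not show, and the hyperparameters are illustrative rather than the 32B run's settings.

import torch.nn as nn

from olmo_core.optim import SkipStepAdamWConfig

model = nn.Linear(8, 8)  # stand-in for the real transformer

# The "skip step" behavior is meant to drop optimizer updates on loss/grad-norm
# spikes (see "Skipped steps and automatic spike analysis" in the commit list)
# instead of letting a single bad batch derail training.
optim_config = SkipStepAdamWConfig(
    lr=6e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
optim = optim_config.build(model)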