pull fixes from 32B branch (#139)
This PR pulls in the generally important changes from #121.
epwalsh authored Jan 21, 2025
1 parent 48abe8c commit 7633461
Showing 18 changed files with 294 additions and 142 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.

### Changed
@@ -25,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch.
- Fixed bug where GCS upload does not retry on transient failures.

## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27

15 changes: 11 additions & 4 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -63,6 +63,7 @@ def save_state_dict(
state_dict: Dict[str, Any],
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
):
"""
Save an arbitrary state dictionary to a distributed format that can be loaded again with
@@ -80,7 +81,7 @@ def save_state_dict(
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
)

@@ -93,6 +94,7 @@ def save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
) -> None:
"""
Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -123,7 +125,7 @@
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
planner=planner,
)
@@ -137,6 +139,7 @@ def async_save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
) -> Future[None]:
"""
An async version of :func:`save_model_and_optim_state()`.
@@ -148,7 +151,7 @@
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
planner=planner,
)
@@ -164,6 +167,7 @@ def load_model_and_optim_state(
key_mapping: Optional[Dict[str, str]] = None,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
thread_count: Optional[int] = None,
):
"""
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -201,10 +205,13 @@
This dictionary should map current keys to keys in the checkpoint to be loaded.
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
:param work_dir: A working directory for caching files/directories.
:param thread_count: Set the number of threads used for certain operations.
"""
dir = normalize_path(dir)
state_dict = _prepare_state_dict(model, optim, process_group=process_group)
reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir)
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)

if key_mapping is not None:
metadata = reader.read_metadata()
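
The new `thread_count` argument above is forwarded to `RemoteFileSystemWriter` / `RemoteFileSystemReader`, so callers can tune checkpoint I/O parallelism. A minimal usage sketch follows; the `dir, model, optim` positional order, the use of a local path, and running this outside a distributed job are assumptions for illustration, not part of this diff.

```python
import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import (
    load_model_and_optim_state,
    save_model_and_optim_state,
)

model = nn.Linear(16, 16)
optim = torch.optim.AdamW(model.parameters())

# Use more I/O threads when writing the (possibly remote) checkpoint.
save_model_and_optim_state(
    "/tmp/ckpt",  # could also be e.g. "gs://bucket/ckpt" or "s3://bucket/ckpt"
    model,
    optim,
    save_overwrite=True,
    thread_count=8,
)

# The same knob exists on the read path, alongside ``pre_download``.
load_model_and_optim_state("/tmp/ckpt", model, optim, thread_count=8)
```
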
1 change: 1 addition & 0 deletions src/olmo_core/distributed/utils.py
@@ -92,6 +92,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")

if backend_supports_cuda(backend):
# Set CUDA device.
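
`NCCL_DEBUG_SUBSYS=INIT,NET` narrows NCCL's debug logging to the initialization and networking subsystems. The `set_env_var` helper itself is not shown in this diff; below is a plausible sketch of what such a helper does (only set the variable if the user has not already set it), and its exact behavior is an assumption.

```python
import os


def set_env_var(name: str, value: str, override: bool = False) -> None:
    # Hypothetical helper: respect a value the user already exported unless
    # ``override`` is requested.
    if override or name not in os.environ:
        os.environ[name] = value


# Limit NCCL debug output to the init and networking subsystems.
set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")
```
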
1 change: 1 addition & 0 deletions src/olmo_core/internal/common.py
@@ -102,6 +102,7 @@ def build_launch_config(
# Setup python environment.
"conda shell.bash activate base",
"pip install -e '.[all]'",
"pip install --upgrade beaker-py",
# Quickly try a new version of PyTorch like this
# "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121",
"pip freeze",
1 change: 1 addition & 0 deletions src/olmo_core/internal/experiment.py
@@ -130,6 +130,7 @@ def build_common_components(
root_dir=root_dir,
cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
cluster=cluster,
nccl_debug=False,
)

beaker_user = get_beaker_username()
51 changes: 38 additions & 13 deletions src/olmo_core/io.py
@@ -532,16 +532,25 @@ def _get_gcs_client():


def _gcs_is_retriable(exc: Exception) -> bool:
from google.api_core.exceptions import BadRequest
from google.api_core.retry import if_transient_error

return if_transient_error(exc) or isinstance(exc, requests.exceptions.Timeout)
return (
if_transient_error(exc)
or isinstance(exc, requests.exceptions.Timeout)
or isinstance(exc, BadRequest) # Weird choice, but Google throws this transiently
)


def _get_gcs_retry():
from google.api_core.retry import Retry

return Retry(
predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0
predicate=_gcs_is_retriable, # NOTE: it appears google might ignore this
initial=1.0,
maximum=10.0,
multiplier=2.0,
timeout=600.0,
)


@@ -554,7 +563,7 @@ def _get_gcs_conditional_retry():
return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"])


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_file_size(bucket_name: str, key: str) -> int:
from google.api_core.exceptions import NotFound

@@ -569,35 +578,51 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
return blob.size


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
from google.api_core.exceptions import NotFound

storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
try:
blob.reload()
blob.reload(retry=_get_gcs_retry())
except NotFound:
raise FileNotFoundError(f"gs://{bucket_name}/{key}")
return blob.download_as_bytes(
start=bytes_start, end=bytes_start + num_bytes - 1, retry=_get_gcs_retry()
start=bytes_start,
end=bytes_start + num_bytes - 1,
retry=_get_gcs_retry(),
checksum=None, # type: ignore
)


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
if not save_overwrite and blob.exists():
raise FileExistsError(
f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
)
blob.upload_from_filename(source, retry=_get_gcs_conditional_retry())

generation: int = 0
if blob.exists(retry=_get_gcs_retry()):
if not save_overwrite:
raise FileExistsError(
f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
)

@retriable()
blob.reload(retry=_get_gcs_retry())
assert blob.generation is not None
generation = blob.generation

blob.upload_from_filename(
source,
if_generation_match=generation,
retry=_get_gcs_conditional_retry(),
checksum=None,
)


@retriable(retry_condition=_gcs_is_retriable)
def _gcs_clear_directory(bucket_name: str, prefix: str):
from google.api_core.exceptions import NotFound

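
The interesting part of the upload change is that the blob's generation is captured before writing and passed as `if_generation_match`, which turns a retried overwrite into a conditional request (and, with `if_generation_match=0`, a create-only request). A self-contained sketch of that pattern, with made-up bucket and object names and the assumption that `google-cloud-storage` credentials are configured:

```python
from google.cloud import storage
from google.api_core.retry import Retry, if_transient_error

retry = Retry(
    predicate=if_transient_error, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0
)

client = storage.Client()
blob = client.bucket("my-bucket").blob("checkpoints/model.pt")  # hypothetical names

# 0 means "only succeed if the object does not exist yet"; otherwise require
# that the object still has the generation we last observed.
generation = 0
if blob.exists(retry=retry):
    blob.reload(retry=retry)
    generation = blob.generation

# If a concurrent writer bumped the generation in the meantime, GCS rejects the
# upload with a precondition failure instead of silently clobbering the newer
# object -- which is also what makes the operation safe to retry.
blob.upload_from_filename("model.pt", if_generation_match=generation, retry=retry)
```
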
13 changes: 12 additions & 1 deletion src/olmo_core/launch/beaker.py
@@ -317,12 +317,23 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
"#!/usr/bin/env bash",
"set -exuo pipefail",
"[[ -d /var/lib/tcpxo/lib64 ]] && export LD_LIBRARY_PATH=/var/lib/tcpxo/lib64:$LD_LIBRARY_PATH",
# Setup the kernel cache directory used by pytorch
"mkdir -p /root/.cache/torch/kernels && export PYTORCH_KERNEL_CACHE_PATH=/root/.cache/torch/kernels",
"mkdir -p /olmo-core-runtime",
"cd /olmo-core-runtime",
*self.setup_steps,
]

if torchrun:
if any(["augusta" in cluster for cluster in self.clusters]):
entrypoint_script.append(
"export BEAKER_REPLICA_RANK=$("
"python -m olmo_core.launch.reorder_ranks_in_gcp "
"${BEAKER_REPLICA_RANK} "
"${BEAKER_REPLICA_COUNT} "
"${BEAKER_LEADER_REPLICA_HOSTNAME}"
")"
)
entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"')
else:
entrypoint_script.append('python "$@"')
@@ -341,7 +352,7 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
leader_selection=self.num_nodes > 1,
host_networking=self.num_nodes > 1
or any(["augusta" in cluster for cluster in self.clusters]),
propagate_failure=True if self.num_nodes > 1 else None,
propagate_failure=False if self.num_nodes > 1 else None,
propagate_preemption=True if self.num_nodes > 1 else None,
synchronized_start_timeout="90m" if self.num_nodes > 1 else None,
resources=TaskResources(gpu_count=self.num_gpus, shared_memory="10GiB"),
70 changes: 70 additions & 0 deletions src/olmo_core/launch/reorder_ranks_in_gcp.py
@@ -0,0 +1,70 @@
import argparse
import sys

import requests
import torch.distributed as dist
from urllib3.exceptions import MaxRetryError, NameResolutionError


def main():
parser = argparse.ArgumentParser()
parser.add_argument("rank", type=int, help="Worker number")
parser.add_argument("world_size", type=int, help="Total number of workers")
parser.add_argument("master_addr", help="Hostname of worker 0")
parser.add_argument("--master_port", type=int, default=29501, help="Port for TCPStore")
parser.add_argument("--debug", action="store_true", help="Enable debug mode (outside of GCP)")
args = parser.parse_args()

# Create or connect to the store
store = dist.TCPStore(
host_name=args.master_addr,
port=args.master_port,
world_size=args.world_size,
is_master=(args.rank == 0),
)

# Get our own host id
if args.debug:
import socket

host_id = f"{socket.gethostname()}_{args.rank}"
else:
try:
response = requests.get(
"http://metadata.google.internal/computeMetadata/v1/instance/attributes/physical_host",
headers={"Metadata-Flavor": "Google"},
)
assert response.status_code == 200
host_id = response.text.strip()
except requests.exceptions.ConnectionError as e:
# Unwrap the exception
e = e.args[0]
if not isinstance(e, MaxRetryError):
raise
e = e.reason
if not isinstance(e, NameResolutionError):
raise
# Seems we called this outside of GCP, so we do nothing and just print our original rank.
print(args.rank)
sys.exit(0)

# Find the index of our host id
store.set(f"node_{args.rank}_hostid", host_id)
store.wait([f"node_{i}_hostid" for i in range(args.world_size)])
all_host_ids = [store.get(f"node_{i}_hostid").decode("UTF-8") for i in range(args.world_size)]
assert len(set(all_host_ids)) == len(all_host_ids)
assert host_id in all_host_ids
rank0_host_id = all_host_ids[0]
all_host_ids.sort()
# Rank 0 needs to remain rank 0, so we reshuffle around it
rank0_index = all_host_ids.index(rank0_host_id)
all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index]
print(all_host_ids.index(host_id))

# Make sure we're all done before exiting
store.set(f"node_{args.rank}_done", host_id)
store.wait([f"node_{i}_done" for i in range(args.world_size)])


if __name__ == "__main__":
main()
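
Each node looks up its `physical_host` ID from the GCP metadata server, publishes it through a `TCPStore`, and then every node deterministically recomputes its rank by sorting the host IDs (presumably so that topologically close hosts end up with adjacent ranks) while rotating the sorted list so the original rank-0 host keeps rank 0. A worked example of just that reordering step, with made-up host IDs:

```python
# The list index is the original Beaker replica rank.
all_host_ids = ["host-c", "host-a", "host-d", "host-b"]
rank0_host_id = all_host_ids[0]                  # "host-c" must stay at rank 0

all_host_ids.sort()                              # ["host-a", "host-b", "host-c", "host-d"]
rank0_index = all_host_ids.index(rank0_host_id)  # 2
all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index]
# -> ["host-c", "host-d", "host-a", "host-b"]

# In the script each node prints only its own new rank; collected here for illustration.
new_ranks = {host: rank for rank, host in enumerate(all_host_ids)}
print(new_ranks)  # {'host-c': 0, 'host-d': 1, 'host-a': 2, 'host-b': 3}
```
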
15 changes: 9 additions & 6 deletions src/olmo_core/nn/transformer/config.py
@@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
)

@classmethod
def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
"""
A 26B OLMo model config.
A 32B OLMo model config.
"""
d_model = 5120
return cls.llama_like(
vocab_size=vocab_size,
d_model=7168,
n_layers=kwargs.pop("n_layers", 40),
n_heads=kwargs.pop("n_heads", 56),
d_model=d_model,
n_layers=kwargs.pop("n_layers", 64),
n_heads=kwargs.pop("n_heads", 40),
n_kv_heads=kwargs.pop("n_kv_heads", 8),
block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm),
qk_norm=kwargs.pop("qk_norm", True),
rope_theta=kwargs.pop("rope_theta", 500_000),
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024),
hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512),
hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)),
layer_norm_eps=1e-6,
**kwargs,
)
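
The new `hidden_size_multiplier` of `27648 / (8 * d_model / 3)` targets a feed-forward width of exactly 27648 for `d_model = 5120`. A quick check of that arithmetic, assuming `llama_like` follows the usual Llama convention of scaling `8 * d_model / 3` by the multiplier and rounding to a multiple of `hidden_size_multiple_of`:

```python
d_model = 5120
multiple_of = 512

base = 8 * d_model / 3     # ~13653.33, the default Llama-style FFN width
multiplier = 27648 / base  # ~2.025
target = multiplier * base # 27648, by construction

print(round(target), round(target) % multiple_of)  # 27648 0 -- already a multiple of 512
```
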
4 changes: 3 additions & 1 deletion src/olmo_core/optim/__init__.py
@@ -1,5 +1,5 @@
from .adam import AdamConfig
from .adamw import AdamWConfig
from .adamw import AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig
from .config import OptimConfig, OptimGroupOverride
from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig
from .scheduler import (
@@ -18,6 +18,8 @@
"OptimGroupOverride",
"SkipStepOptimizer",
"AdamWConfig",
"SkipStepAdamWConfig",
"SkipStepAdamW",
"AdamConfig",
"LionConfig",
"Lion",
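
These exports pair with the `SkipStepAdamW` entry added to the CHANGELOG above. A hypothetical usage sketch, assuming the config follows the same `build(model)` pattern as the existing optimizer configs and that, like the other `SkipStep*` optimizers, it skips the parameter update on outlier (loss-spike) steps; none of those specifics are shown in this diff.

```python
import torch.nn as nn

from olmo_core.optim import SkipStepAdamWConfig

model = nn.Linear(16, 16)

# Assumed API: an ``lr`` keyword and a ``build(model)`` method mirroring
# ``AdamWConfig``; check the actual class for the real signature.
optim = SkipStepAdamWConfig(lr=3e-4).build(model)
```
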