32B with Hybrid Sharding #134

Draft: dirkgr wants to merge 67 commits into 32B from dirkg/32BHybrid

Commits (67)
1a6adde
Experiments with hybrid sharding
dirkgr Dec 23, 2024
638d8b2
Updated notebook
dirkgr Dec 23, 2024
6cea547
One replica per node
dirkgr Dec 23, 2024
bbe0d98
Two nodes
dirkgr Dec 24, 2024
4a567d6
Updated dashboard
dirkgr Dec 24, 2024
8c666a5
Load optimizer and model in the other order
dirkgr Dec 24, 2024
31555fb
Update the notebook
dirkgr Dec 24, 2024
1c4ae16
Updated notebook
dirkgr Dec 27, 2024
d315524
Merge remote-tracking branch 'origin/32B' into dirkg/32BHybrid
dirkgr Dec 30, 2024
4f85f4d
Problem doesn't repro with four nodes
dirkgr Dec 30, 2024
2ce06ae
SkipAdam optimizer has problems during startup?
dirkgr Dec 30, 2024
1558306
Go back to FSDP to see if we can survive more than 9 steps
dirkgr Dec 30, 2024
6e92ed2
Merge remote-tracking branch 'origin/32B' into dirkg/32BHybrid
dirkgr Jan 3, 2025
fe2bb8b
Turn hybrid back on
dirkgr Jan 3, 2025
068b96e
Get faster answers
dirkgr Jan 3, 2025
56be57c
More informative error message
dirkgr Jan 3, 2025
ad773fe
Forgot about GPUs
dirkgr Jan 3, 2025
d8c2d2a
Try it with more nodes
dirkgr Jan 3, 2025
f751c77
Put SkipStep back
dirkgr Jan 4, 2025
388f30c
Don't have that many nodes available now
dirkgr Jan 5, 2025
79d3370
Merge remote-tracking branch 'origin/32B' into dirkg/32BHybrid
dirkgr Jan 6, 2025
4cae22b
Adds a way to load unsharded checkpoints
dirkgr Jan 6, 2025
a2b9bdd
Turn off compiling the optimizer
dirkgr Jan 6, 2025
b5aa0f8
Checkpoint directories work different now
dirkgr Jan 6, 2025
9f65f58
30 minute timeout
dirkgr Jan 7, 2025
ed5a94c
Support `key_mapping` when loading unsharded checkpoints
dirkgr Jan 7, 2025
94a08bc
Don't broadcast from 0, because apparently that hangs
dirkgr Jan 7, 2025
fff3276
More logging
dirkgr Jan 7, 2025
549693b
I think this is always `None`, but I want to be sure.
dirkgr Jan 7, 2025
18a96f4
Better log message
dirkgr Jan 7, 2025
98a5a2d
Even better messaging
dirkgr Jan 7, 2025
0b957c6
Log all threads before downloading
dirkgr Jan 7, 2025
df62dd5
Let's see download progress
dirkgr Jan 7, 2025
9afadd2
Make printing stack traces more readable
dirkgr Jan 7, 2025
0b8f047
Don't log all threads
dirkgr Jan 7, 2025
7166a09
15 minute timeout
dirkgr Jan 7, 2025
f4f2a58
Revert "Let's see download progress"
dirkgr Jan 7, 2025
67547b5
Use the rank we mean
dirkgr Jan 7, 2025
56c53b2
Finalize config after applying overrides
dirkgr Jan 7, 2025
16d99b4
Download checkpoint before doing anything else
dirkgr Jan 7, 2025
8e32582
We don't have gsutil in the path.
dirkgr Jan 7, 2025
2e976a7
Install google cloud CLI in the Docker container
dirkgr Jan 7, 2025
9457dcc
Installing this way doesn't work
dirkgr Jan 7, 2025
a93d01d
Less output from `gsutil`
dirkgr Jan 8, 2025
899686a
Disable evals so we can experiment faster
dirkgr Jan 8, 2025
50f60c3
Revert "Disable evals so we can experiment faster"
dirkgr Jan 8, 2025
0caa78a
Adds an unsharding script
dirkgr Jan 8, 2025
05979c2
Adds gcloud utilities to the image
dirkgr Jan 8, 2025
034dea9
Merge branch 'dirkg/32BHybrid' of https://github.com/allenai/OLMo-cor…
dirkgr Jan 8, 2025
24db150
Bring back selective activation checkpointing
dirkgr Jan 8, 2025
82dd576
Merge branch 'dirkg/32BHybrid' of https://github.com/allenai/OLMo-cor…
dirkgr Jan 8, 2025
7d0b268
Checkpoint more stuff
dirkgr Jan 8, 2025
e55b84b
Fix directory naming
dirkgr Jan 8, 2025
51ae3a4
Initialize modules that are wrapped
dirkgr Jan 8, 2025
a2f010e
More informative error
dirkgr Jan 8, 2025
8d69ce1
Build parameter groups even with fancy activation checkpointing schemes
dirkgr Jan 8, 2025
c13e4cf
Don't preserve randomness
dirkgr Jan 8, 2025
e39a11c
We can use globbing to checkpoint
dirkgr Jan 8, 2025
aace4af
Switch back to full sharding
dirkgr Jan 8, 2025
c79584e
More precise checkpointing
dirkgr Jan 8, 2025
e1f5956
Make a baseline with 16 nodes
dirkgr Jan 8, 2025
af57282
Checkpoint attention
dirkgr Jan 9, 2025
1759a53
Checkpoint some more
dirkgr Jan 9, 2025
5395984
Revert "Checkpoint some more"
dirkgr Jan 9, 2025
76db353
Revert "Checkpoint attention"
dirkgr Jan 9, 2025
50d93cf
Reorder ranks in GCP
dirkgr Jan 9, 2025
3a6ebff
Rank 0 needs to remain rank 0
dirkgr Jan 9, 2025
7 changes: 7 additions & 0 deletions src/Dockerfile
```diff
@@ -57,6 +57,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     git && \
     rm -rf /var/lib/apt/lists/*
 
+# Install google cloud CLI
+RUN apt-get update && apt-get install -y apt-transport-https ca-certificates curl gnupg && \
+    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
+    echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
+    apt-get update && \
+    apt-get install -y google-cloud-cli
+
 # Install MLNX OFED user-space drivers
 # See https://docs.nvidia.com/networking/pages/releaseview.action?pageId=15049785#Howto:DeployRDMAacceleratedDockercontaineroverInfiniBandfabric.-Dockerfile
 ENV MOFED_VER="24.01-0.3.3.1"
```
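Several commits above ("We don't have gsutil in the path", "Installing this way doesn't work") deal with getting these tools onto the PATH. A minimal sanity check one could run inside an image built from this Dockerfile (assuming only that the container has Python):

```python
# Verify the Google Cloud tools the training code shells out to are on PATH.
import shutil

for tool in ("gcloud", "gsutil"):
    assert shutil.which(tool) is not None, f"{tool} is not on PATH"
```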
2 changes: 1 addition & 1 deletion src/olmo_core/data/data_loader.py
```diff
@@ -489,7 +489,7 @@ def __init__(
             assert isinstance(self.dataset, NumpyFSLDataset)
             if self.rank_batch_size % self.dataset.sequence_length != 0:
                 raise OLMoConfigurationError(
-                    "rank batch size (in tokens) must be divisible by sequence length"
+                    f"rank batch size (in tokens) must be divisible by sequence length; got rbs={self.rank_batch_size}, sl={self.dataset.sequence_length}"
                 )
 
     @property
```
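The sharpened message now reports the offending values. For intuition, the invariant it guards (the numbers below are made up):

```python
# Each rank must receive a whole number of fixed-length sequences per batch.
rank_batch_size = 8192   # tokens per rank per step
sequence_length = 4096   # tokens per sequence

assert rank_batch_size % sequence_length == 0   # 2 sequences per rank: OK
assert 8192 % 3000 != 0                         # this pairing would raise OLMoConfigurationError
```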
172 changes: 118 additions & 54 deletions src/olmo_core/distributed/checkpoint/__init__.py
```diff
@@ -39,8 +39,9 @@
 from torch.distributed.checkpoint.metadata import Metadata
 
 from olmo_core.aliases import PathOrStr
-from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path
-from olmo_core.utils import gc_cuda, wait_for
+from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path, resource_path, file_exists
+from olmo_core.utils import gc_cuda, wait_for, log_all_threads
+from . import safetensors_util
 
 from ..utils import barrier, get_fs_local_rank, is_distributed
 from .filesystem import RemoteFileSystemReader, RemoteFileSystemWriter
@@ -207,71 +208,134 @@ def load_model_and_optim_state(
     :param work_dir: A working directory for caching files/directories.
     :param thread_count: Set the number of threads used for certain operations.
     """
-    dir = normalize_path(dir)
-    state_dict = _prepare_state_dict(model, optim, process_group=process_group)
-    reader = RemoteFileSystemReader(
-        dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
-    )
+    assert process_group is None
 
-    if key_mapping is not None:
-        metadata = reader.read_metadata()
-        for current_key, original_key in key_mapping.items():
-            if f"model.{original_key}" not in metadata.state_dict_metadata:
-                continue
+    dir = normalize_path(dir)
 
-            log.info(f"Mapping current param '{current_key}' to '{original_key}' in checkpoint")
-            state_dict["model"][original_key] = state_dict["model"].pop(current_key)
+    can_load_unsharded = (
+        file_exists(f"{dir}_unsharded/model.safetensors") and
+        file_exists(f"{dir}_unsharded/optim.safetensors")
+    )
 
-            if optim is None:
-                continue
+    if can_load_unsharded:
+        if get_fs_local_rank() == 0:
+            log.info(f"Local rank 0 loading {dir}/model.safetensors")
+            model_path = resource_path(dir, "model.safetensors", local_cache=work_dir)
+            log.info(f"Local rank 0 loaded {dir}/model.safetensors")
+            dist.barrier()
+        else:
+            log.info("Nonzero local rank waiting for rank 0 to load model.safetensors")
+            dist.barrier()
+            log.info("Nonzero local rank loading model.safetensors")
+            model_path = resource_path(dir, "model.safetensors", local_cache=work_dir)
+            log.info("Nonzero local rank loaded model.safetensors")
 
-            state_dict["optim"]["state"][original_key] = state_dict["optim"]["state"].pop(
-                current_key
-            )
-            for group in state_dict["optim"]["param_groups"]:
-                if current_key in group["params"]:
-                    idx = group["params"].index(current_key)
-                    group["params"][idx] = original_key
-                    break
+        model_state_dict = safetensors_util.safetensors_file_to_state_dict(model_path)
+        if key_mapping is not None:
+            for current_key, original_key in key_mapping.items():
+                if original_key in model_state_dict:
+                    assert current_key not in model_state_dict, f"Mapping {original_key} to {current_key} in the model state dict would overwrite existing {current_key}"
+                    model_state_dict[current_key] = model_state_dict.pop(original_key)
+
+        sd_options = dist_cp_sd.StateDictOptions(
+            strict=True,
+            full_state_dict=True,
+            broadcast_from_rank0=False,
+        )
+        dist_cp_sd.set_model_state_dict(model, model_state_dict, options=sd_options)
+        del model_path
+        del model_state_dict
+        gc_cuda()
+
+        if optim is not None:
+            if get_fs_local_rank() == 0:
+                optim_path = resource_path(dir, "optim.safetensors", local_cache=work_dir)
+                dist.barrier()
+            else:
+                dist.barrier()
+                optim_path = resource_path(dir, "optim.safetensors", local_cache=work_dir)
+
+            optim_state_dict = safetensors_util.safetensors_file_to_state_dict(optim_path)
+            if key_mapping is not None:
+                for current_key, original_key in key_mapping.items():
+                    if original_key in optim_state_dict["state"]:
+                        assert current_key not in optim_state_dict["state"], f"Mapping {original_key} to {current_key} in the optimizer state dict would overwrite existing {current_key}"
+                        optim_state_dict["state"][current_key] = optim_state_dict["state"].pop(original_key)
+                        for group in optim_state_dict["param_groups"]:
+                            if original_key in group["params"]:
+                                idx = group["params"].index(original_key)
+                                group["params"][idx] = current_key
+                                break
+
+            dist_cp_sd.set_optimizer_state_dict(model, optim, optim_state_dict, options=sd_options)
+            del optim_path
+            del optim_state_dict
+            gc_cuda()
+    else:
+        state_dict = _prepare_state_dict(model, optim, process_group=process_group)
+        reader = RemoteFileSystemReader(
+            dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
+        )
 
-    dist_cp.load(
-        state_dict,
-        checkpoint_id=dir,
-        storage_reader=reader,
-        process_group=process_group,
-    )
+        if key_mapping is not None:
+            metadata = reader.read_metadata()
+            for current_key, original_key in key_mapping.items():
+                if f"model.{original_key}" not in metadata.state_dict_metadata:
+                    continue
+
+                log.info(f"Mapping current param '{current_key}' to '{original_key}' in checkpoint")
+                state_dict["model"][original_key] = state_dict["model"].pop(current_key)
+
+                if optim is None:
+                    continue
+
+                state_dict["optim"]["state"][original_key] = state_dict["optim"]["state"].pop(
+                    current_key
+                )
+                for group in state_dict["optim"]["param_groups"]:
+                    if current_key in group["params"]:
+                        idx = group["params"].index(current_key)
+                        group["params"][idx] = original_key
+                        break
+
+        dist_cp.load(
+            state_dict,
+            checkpoint_id=dir,
+            storage_reader=reader,
+            process_group=process_group,
+        )
 
-    if key_mapping is not None:
-        metadata = reader.read_metadata()
-        for current_key, original_key in key_mapping.items():
-            if f"model.{original_key}" not in metadata.state_dict_metadata:
-                continue
+        if key_mapping is not None:
+            metadata = reader.read_metadata()
+            for current_key, original_key in key_mapping.items():
+                if f"model.{original_key}" not in metadata.state_dict_metadata:
+                    continue
 
-            state_dict["model"][current_key] = state_dict["model"].pop(original_key)
+                state_dict["model"][current_key] = state_dict["model"].pop(original_key)
 
-            if optim is None:
-                continue
+                if optim is None:
+                    continue
 
-            state_dict["optim"]["state"][current_key] = state_dict["optim"]["state"].pop(
-                original_key
-            )
-            for group in state_dict["optim"]["param_groups"]:
-                if original_key in group["params"]:
-                    idx = group["params"].index(original_key)
-                    group["params"][idx] = current_key
-                    break
+                state_dict["optim"]["state"][current_key] = state_dict["optim"]["state"].pop(
+                    original_key
+                )
+                for group in state_dict["optim"]["param_groups"]:
+                    if original_key in group["params"]:
+                        idx = group["params"].index(original_key)
+                        group["params"][idx] = current_key
+                        break
 
-    dist_cp_sd.set_model_state_dict(
-        model, state_dict["model"], options=dist_cp_sd.StateDictOptions(strict=True)
-    )
-    gc_cuda()
+        dist_cp_sd.set_model_state_dict(
+            model, state_dict["model"], options=dist_cp_sd.StateDictOptions(strict=True)
+        )
+        gc_cuda()
 
-    if optim is not None:
-        dist_cp_sd.set_optimizer_state_dict(
-            model, optim, state_dict["optim"], options=dist_cp_sd.StateDictOptions(strict=True)
-        )
-        gc_cuda()
+        if optim is not None:
+            dist_cp_sd.set_optimizer_state_dict(
+                model, optim, state_dict["optim"], options=dist_cp_sd.StateDictOptions(strict=True)
+            )
+            gc_cuda()
 
 
 def unshard_checkpoint(
     dir: PathOrStr,
```
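The `key_mapping` handling above renames checkpoint keys to match the live model's parameter names. A standalone illustration of the rename semantics (dict contents are made up; no distributed setup required):

```python
# current parameter name -> name as it appears in the checkpoint
key_mapping = {"w_qkv.weight": "att_proj.weight"}

# as read from model.safetensors (value elided)
model_state_dict = {"att_proj.weight": "tensor..."}

for current_key, original_key in key_mapping.items():
    if original_key in model_state_dict:
        # refuse to silently clobber an existing entry
        assert current_key not in model_state_dict
        model_state_dict[current_key] = model_state_dict.pop(original_key)

assert "w_qkv.weight" in model_state_dict  # key now matches the live model
```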
82 changes: 82 additions & 0 deletions src/olmo_core/distributed/checkpoint/safetensors_util.py
```python
import base64
import pickle
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import safetensors.torch
import torch

from olmo_core.aliases import PathOrStr

__all__ = [
    "state_dict_to_safetensors_file",
    "safetensors_file_to_state_dict",
]


@dataclass(eq=True, frozen=True)
class STKey:
    keys: Tuple
    value_is_pickled: bool


def encode_key(key: STKey) -> str:
    b = pickle.dumps((key.keys, key.value_is_pickled))
    b = base64.urlsafe_b64encode(b)
    return str(b, "ASCII")


def decode_key(key: str) -> STKey:
    b = base64.urlsafe_b64decode(key)
    keys, value_is_pickled = pickle.loads(b)
    return STKey(keys, value_is_pickled)


def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
    result = {}
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            result[STKey((key,), False)] = value
        elif isinstance(value, dict):
            value = flatten_dict(value)
            for inner_key, inner_value in value.items():
                result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
        else:
            pickled = bytearray(pickle.dumps(value))
            pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
            result[STKey((key,), True)] = pickled_tensor
    return result


def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
    result: Dict = {}

    for key, value in d.items():
        if key.value_is_pickled:
            value = pickle.loads(value.numpy().data)

        target_dict = result
        for k in key.keys[:-1]:
            new_target_dict = target_dict.get(k)
            if new_target_dict is None:
                new_target_dict = {}
                target_dict[k] = new_target_dict
            target_dict = new_target_dict
        target_dict[key.keys[-1]] = value

    return result


def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
    state_dict = flatten_dict(state_dict)
    state_dict = {encode_key(k): v for k, v in state_dict.items()}
    safetensors.torch.save_file(state_dict, filename)


def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
    if map_location is None:
        map_location = "cpu"
    state_dict = safetensors.torch.load_file(filename, device=map_location)
    state_dict = {decode_key(k): v for k, v in state_dict.items()}
    return unflatten_dict(state_dict)
```
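Since the module is self-contained, a round-trip sketch shows the intended contract: tensors are stored natively, everything else is pickled into `uint8` tensors, and nesting is encoded into the keys (the file path below is illustrative; module path per the diff header):

```python
import torch

from olmo_core.distributed.checkpoint.safetensors_util import (
    safetensors_file_to_state_dict,
    state_dict_to_safetensors_file,
)

state = {
    "model": {"w": torch.zeros(2, 2)},   # tensors round-trip natively
    "optim": {"step": 100},              # non-tensors are pickled to uint8 tensors
}
state_dict_to_safetensors_file(state, "/tmp/example.safetensors")
restored = safetensors_file_to_state_dict("/tmp/example.safetensors")

assert torch.equal(restored["model"]["w"], state["model"]["w"])
assert restored["optim"]["step"] == 100
```

Note that the base64-pickled keys and pickled values make this a private serialization detail rather than an interchange format other safetensors consumers could read meaningfully.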
3 changes: 2 additions & 1 deletion src/olmo_core/distributed/utils.py
```diff
@@ -67,7 +67,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
         set_env_var("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")
         set_env_var("NCCL_NET_GDR_LEVEL", "PIX")
         set_env_var("NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING", "0")
-        set_env_var("NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS", "600000")
+        set_env_var("NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS", str(30 * 60 * 1000))
         set_env_var("NCCL_NVLS_ENABLE", "0")
         set_env_var("NCCL_USE_SNAP", "1")
         set_env_var("NCCL_FASTRAK_USE_LLCM", "1")
@@ -93,6 +93,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
         )
         set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
         set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")
+        set_env_var("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", str(15 * 60))
 
     if backend_supports_cuda(backend):
         # Set CUDA device.
```
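Written out, the two timeouts are 30 * 60 * 1000 = 1,800,000 ms (30 minutes) and 15 * 60 = 900 s (15 minutes), matching the "30 minute timeout" and "15 minute timeout" commits. A plausible sketch of the `set_env_var` helper used here, assuming it only fills in variables the launcher has not already exported (the real implementation lives elsewhere in olmo_core and may differ):

```python
import os

def set_env_var(name: str, value: str, override: bool = False) -> None:
    # Assumed semantics: respect values already set in the environment.
    if override or name not in os.environ:
        os.environ[name] = value
```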
4 changes: 2 additions & 2 deletions src/olmo_core/internal/experiment.py
```diff
@@ -229,11 +229,11 @@ def build_config(
         trainer=trainer,
     )
 
+    config = config.merge(overrides)
+
     if finalize_config is not None:
         finalize_config(config)
 
-    config = config.merge(overrides)
-
     if config.model.float8_config is not None and config.model.float8_config.enabled:
         config.trainer.add_callback(
             "float8_handler", Float8HandlerCallback(config=config.model.float8_config)
```
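Per the "Finalize config after applying overrides" commit, this reordering means `finalize_config` now sees the fully merged config, so any settings it derives reflect command-line overrides rather than pre-override values.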
8 changes: 8 additions & 0 deletions src/olmo_core/launch/beaker.py
```diff
@@ -323,6 +323,14 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
         ]
 
         if torchrun:
+            entrypoint_script.append(
+                "export BEAKER_REPLICA_RANK=$("
+                "python src/scripts/reorder_ranks_in_gcp.py "
+                "${BEAKER_REPLICA_RANK} "
+                "${BEAKER_REPLICA_COUNT} "
+                "${BEAKER_LEADER_REPLICA_HOSTNAME}"
+                ")"
+            )
             entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"')
         else:
             entrypoint_script.append('python "$@"')
```
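The exported `BEAKER_REPLICA_RANK` becomes whatever `reorder_ranks_in_gcp.py` prints to stdout. That script is not included in this diff view; a hypothetical sketch of its contract, consistent with the "Reorder ranks in GCP" and "Rank 0 needs to remain rank 0" commits:

```python
# Hypothetical sketch only -- the real src/scripts/reorder_ranks_in_gcp.py is
# not shown in this PR view. Contract: read (rank, replica_count,
# leader_hostname) from argv and print the possibly-remapped rank to stdout.
import socket
import sys

rank = int(sys.argv[1])
replica_count = int(sys.argv[2])
leader_hostname = sys.argv[3]

if socket.gethostname() == leader_hostname:
    print(0)  # the leader must keep rank 0
else:
    # Non-leaders could be re-sorted by GCP topology here; as a placeholder,
    # keep the original rank but never collide with the leader's rank 0.
    print(rank if rank != 0 else replica_count - 1)
```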
4 changes: 4 additions & 0 deletions src/olmo_core/nn/transformer/init.py
```diff
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ActivationWrapper
 
 from olmo_core.config import StrEnum
 
@@ -76,6 +77,9 @@ def init_attention(
         if self == InitMethod.normalized:
             std = d_model**-0.5
 
+        if isinstance(m, ActivationWrapper):
+            m = m._checkpoint_wrapped_module
+
         if isinstance(m, Attention):
             for w in (m.w_q, m.w_k, m.w_v):
                 self._init_linear(w, std=std, generator=generator)
```
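Why the unwrap is needed: activation-checkpoint wrappers change a module's type, so without it the `isinstance(m, Attention)` check would silently skip wrapped blocks and leave them uninitialized (the "Initialize modules that are wrapped" commit). A standalone sketch, where `Attention` is a stand-in rather than olmo_core's class:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    ActivationWrapper,
    checkpoint_wrapper,
)

class Attention(nn.Module):  # stand-in for olmo_core's Attention
    def __init__(self):
        super().__init__()
        self.w_q = nn.Linear(8, 8)

m = checkpoint_wrapper(Attention())
assert not isinstance(m, Attention)        # the wrapper hides the real type
assert isinstance(m, ActivationWrapper)
inner = m._checkpoint_wrapped_module       # unwrap, as the diff does
assert isinstance(inner, Attention)
```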