diff --git a/src/Dockerfile b/src/Dockerfile index 186cedbb..485a6a55 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -57,6 +57,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git && \ rm -rf /var/lib/apt/lists/* +# Install google cloud CLI +RUN apt-get update && apt-get install -y apt-transport-https ca-certificates curl gnupg && \ + curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \ + echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \ + apt-get update && \ + apt-get install -y google-cloud-cli + # Install MLNX OFED user-space drivers # See https://docs.nvidia.com/networking/pages/releaseview.action?pageId=15049785#Howto:DeployRDMAacceleratedDockercontaineroverInfiniBandfabric.-Dockerfile ENV MOFED_VER="24.01-0.3.3.1" diff --git a/src/olmo_core/data/data_loader.py b/src/olmo_core/data/data_loader.py index d9323757..e88bd08b 100644 --- a/src/olmo_core/data/data_loader.py +++ b/src/olmo_core/data/data_loader.py @@ -489,7 +489,7 @@ def __init__( assert isinstance(self.dataset, NumpyFSLDataset) if self.rank_batch_size % self.dataset.sequence_length != 0: raise OLMoConfigurationError( - "rank batch size (in tokens) must be divisible by sequence length" + f"rank batch size (in tokens) must be divisible by sequence length; got rbs={self.rank_batch_size}, sl={self.dataset.sequence_length}" ) @property diff --git a/src/olmo_core/distributed/checkpoint/__init__.py b/src/olmo_core/distributed/checkpoint/__init__.py index 068fc3e8..7465f9e8 100644 --- a/src/olmo_core/distributed/checkpoint/__init__.py +++ b/src/olmo_core/distributed/checkpoint/__init__.py @@ -39,8 +39,9 @@ from torch.distributed.checkpoint.metadata import Metadata from olmo_core.aliases import PathOrStr -from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path -from olmo_core.utils import gc_cuda, wait_for +from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path, resource_path, file_exists +from olmo_core.utils import gc_cuda, wait_for, log_all_threads +from . import safetensors_util from ..utils import barrier, get_fs_local_rank, is_distributed from .filesystem import RemoteFileSystemReader, RemoteFileSystemWriter @@ -207,71 +208,134 @@ def load_model_and_optim_state( :param work_dir: A working directory for caching files/directories. :param thread_count: Set the number of threads used for certain operations. """ - dir = normalize_path(dir) - state_dict = _prepare_state_dict(model, optim, process_group=process_group) - reader = RemoteFileSystemReader( - dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir - ) + assert process_group is None - if key_mapping is not None: - metadata = reader.read_metadata() - for current_key, original_key in key_mapping.items(): - if f"model.{original_key}" not in metadata.state_dict_metadata: - continue + dir = normalize_path(dir) - log.info(f"Mapping current param '{current_key}' to '{original_key}' in checkpoint") - state_dict["model"][original_key] = state_dict["model"].pop(current_key) + can_load_unsharded =( + file_exists(f"{dir}_unsharded/model.safetensors") and + file_exists(f"{dir}_unsharded/optim.safetensors") + ) - if optim is None: - continue + if can_load_unsharded: + if get_fs_local_rank() == 0: + log.info(f"Local rank 0 loading {dir}/model.safetensors") + model_path = resource_path(dir, "model.safetensors", local_cache=work_dir) + log.info(f"Local rank 0 loaded {dir}/model.safetensors") + dist.barrier() + else: + log.info("Nonzero local rank waiting for rank 0 to load model.safetensors") + dist.barrier() + log.info("Nonzero local rank loading model.safetensors") + model_path = resource_path(dir, "model.safetensors", local_cache=work_dir) + log.info("Nonzero local rank loaded model.safetensors") + + model_state_dict = safetensors_util.safetensors_file_to_state_dict(model_path) + if key_mapping is not None: + for current_key, original_key in key_mapping.items(): + if original_key in model_state_dict: + assert current_key not in model_state_dict, f"Mapping {original_key} to {current_key} in the model state dict would overwrite existing {current_key}" + model_state_dict[current_key] = model_state_dict.pop(original_key) + + sd_options = dist_cp_sd.StateDictOptions( + strict=True, + full_state_dict=True, + broadcast_from_rank0=False + ) + dist_cp_sd.set_model_state_dict(model, model_state_dict, options=sd_options) + del model_path + del model_state_dict + gc_cuda() - state_dict["optim"]["state"][original_key] = state_dict["optim"]["state"].pop( - current_key - ) - for group in state_dict["optim"]["param_groups"]: - if current_key in group["params"]: - idx = group["params"].index(current_key) - group["params"][idx] = original_key - break + if optim is not None: + if get_fs_local_rank() == 0: + optim_path = resource_path(dir, "optim.safetensors", local_cache=work_dir) + dist.barrier() + else: + dist.barrier() + optim_path = resource_path(dir, "optim.safetensors", local_cache=work_dir) + + optim_state_dict = safetensors_util.safetensors_file_to_state_dict(optim_path) + if key_mapping is not None: + for current_key, original_key in key_mapping.items(): + if original_key in optim_state_dict["state"]: + assert current_key not in optim_state_dict["state"], f"Mapping {original_key} to {current_key} in the optimizer state dict would overwrite existing {current_key}" + optim_state_dict["state"][current_key] = optim_state_dict["state"].pop(original_key) + for group in optim_state_dict["param_groups"]: + if original_key in group["params"]: + idx = group["params"].index(original_key) + group["params"][idx] = current_key + break + + dist_cp_sd.set_optimizer_state_dict(model, optim, optim_state_dict, options=sd_options) + del optim_path + del optim_state_dict + gc_cuda() + else: + state_dict = _prepare_state_dict(model, optim, process_group=process_group) + reader = RemoteFileSystemReader( + dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir + ) - dist_cp.load( - state_dict, - checkpoint_id=dir, - storage_reader=reader, - process_group=process_group, - ) + if key_mapping is not None: + metadata = reader.read_metadata() + for current_key, original_key in key_mapping.items(): + if f"model.{original_key}" not in metadata.state_dict_metadata: + continue + + log.info(f"Mapping current param '{current_key}' to '{original_key}' in checkpoint") + state_dict["model"][original_key] = state_dict["model"].pop(current_key) + + if optim is None: + continue + + state_dict["optim"]["state"][original_key] = state_dict["optim"]["state"].pop( + current_key + ) + for group in state_dict["optim"]["param_groups"]: + if current_key in group["params"]: + idx = group["params"].index(current_key) + group["params"][idx] = original_key + break + + dist_cp.load( + state_dict, + checkpoint_id=dir, + storage_reader=reader, + process_group=process_group, + ) - if key_mapping is not None: - metadata = reader.read_metadata() - for current_key, original_key in key_mapping.items(): - if f"model.{original_key}" not in metadata.state_dict_metadata: - continue + if key_mapping is not None: + metadata = reader.read_metadata() + for current_key, original_key in key_mapping.items(): + if f"model.{original_key}" not in metadata.state_dict_metadata: + continue - state_dict["model"][current_key] = state_dict["model"].pop(original_key) + state_dict["model"][current_key] = state_dict["model"].pop(original_key) - if optim is None: - continue + if optim is None: + continue - state_dict["optim"]["state"][current_key] = state_dict["optim"]["state"].pop( - original_key - ) - for group in state_dict["optim"]["param_groups"]: - if original_key in group["params"]: - idx = group["params"].index(original_key) - group["params"][idx] = current_key - break - - dist_cp_sd.set_model_state_dict( - model, state_dict["model"], options=dist_cp_sd.StateDictOptions(strict=True) - ) - gc_cuda() + state_dict["optim"]["state"][current_key] = state_dict["optim"]["state"].pop( + original_key + ) + for group in state_dict["optim"]["param_groups"]: + if original_key in group["params"]: + idx = group["params"].index(original_key) + group["params"][idx] = current_key + break - if optim is not None: - dist_cp_sd.set_optimizer_state_dict( - model, optim, state_dict["optim"], options=dist_cp_sd.StateDictOptions(strict=True) + dist_cp_sd.set_model_state_dict( + model, state_dict["model"], options=dist_cp_sd.StateDictOptions(strict=True) ) gc_cuda() + if optim is not None: + dist_cp_sd.set_optimizer_state_dict( + model, optim, state_dict["optim"], options=dist_cp_sd.StateDictOptions(strict=True) + ) + gc_cuda() + def unshard_checkpoint( dir: PathOrStr, diff --git a/src/olmo_core/distributed/checkpoint/safetensors_util.py b/src/olmo_core/distributed/checkpoint/safetensors_util.py new file mode 100644 index 00000000..475f6dc9 --- /dev/null +++ b/src/olmo_core/distributed/checkpoint/safetensors_util.py @@ -0,0 +1,82 @@ +import base64 +import pickle +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import safetensors.torch +import torch + + +__all__ = [ + "state_dict_to_safetensors_file", + "safetensors_file_to_state_dict", +] + +from olmo_core.aliases import PathOrStr + + +@dataclass(eq=True, frozen=True) +class STKey: + keys: Tuple + value_is_pickled: bool + + +def encode_key(key: STKey) -> str: + b = pickle.dumps((key.keys, key.value_is_pickled)) + b = base64.urlsafe_b64encode(b) + return str(b, "ASCII") + + +def decode_key(key: str) -> STKey: + b = base64.urlsafe_b64decode(key) + keys, value_is_pickled = pickle.loads(b) + return STKey(keys, value_is_pickled) + + +def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]: + result = {} + for key, value in d.items(): + if isinstance(value, torch.Tensor): + result[STKey((key,), False)] = value + elif isinstance(value, dict): + value = flatten_dict(value) + for inner_key, inner_value in value.items(): + result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value + else: + pickled = bytearray(pickle.dumps(value)) + pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8) + result[STKey((key,), True)] = pickled_tensor + return result + + +def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict: + result: Dict = {} + + for key, value in d.items(): + if key.value_is_pickled: + value = pickle.loads(value.numpy().data) + + target_dict = result + for k in key.keys[:-1]: + new_target_dict = target_dict.get(k) + if new_target_dict is None: + new_target_dict = {} + target_dict[k] = new_target_dict + target_dict = new_target_dict + target_dict[key.keys[-1]] = value + + return result + + +def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr): + state_dict = flatten_dict(state_dict) + state_dict = {encode_key(k): v for k, v in state_dict.items()} + safetensors.torch.save_file(state_dict, filename) + + +def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict: + if map_location is None: + map_location = "cpu" + state_dict = safetensors.torch.load_file(filename, device=map_location) + state_dict = {decode_key(k): v for k, v in state_dict.items()} + return unflatten_dict(state_dict) diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index 24a0a784..a6f7584e 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -67,7 +67,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut set_env_var("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7") set_env_var("NCCL_NET_GDR_LEVEL", "PIX") set_env_var("NCCL_FASTRAK_ENABLE_HOTPATH_LOGGING", "0") - set_env_var("NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS", "600000") + set_env_var("NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS", str(30 * 60 * 1000)) set_env_var("NCCL_NVLS_ENABLE", "0") set_env_var("NCCL_USE_SNAP", "1") set_env_var("NCCL_FASTRAK_USE_LLCM", "1") @@ -93,6 +93,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut ) set_env_var("NCCL_SOCKET_IFNAME", "enp0s12") set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET") + set_env_var("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", str(15 * 60)) if backend_supports_cuda(backend): # Set CUDA device. diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 1015a60d..cdf0ee26 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -229,11 +229,11 @@ def build_config( trainer=trainer, ) + config = config.merge(overrides) + if finalize_config is not None: finalize_config(config) - config = config.merge(overrides) - if config.model.float8_config is not None and config.model.float8_config.enabled: config.trainer.add_callback( "float8_handler", Float8HandlerCallback(config=config.model.float8_config) diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 435142f3..db660faf 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -323,6 +323,14 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: ] if torchrun: + entrypoint_script.append( + "export BEAKER_REPLICA_RANK=$(" + "python src/scripts/reorder_ranks_in_gcp.py " + "${BEAKER_REPLICA_RANK} " + "${BEAKER_REPLICA_COUNT} " + "${BEAKER_LEADER_REPLICA_HOSTNAME}" + ")" + ) entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"') else: entrypoint_script.append('python "$@"') diff --git a/src/olmo_core/nn/transformer/init.py b/src/olmo_core/nn/transformer/init.py index 56b55bc6..fd94e44c 100644 --- a/src/olmo_core/nn/transformer/init.py +++ b/src/olmo_core/nn/transformer/init.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ActivationWrapper from olmo_core.config import StrEnum @@ -76,6 +77,9 @@ def init_attention( if self == InitMethod.normalized: std = d_model**-0.5 + if isinstance(m, ActivationWrapper): + m = m._checkpoint_wrapped_module + if isinstance(m, Attention): for w in (m.w_q, m.w_k, m.w_v): self._init_linear(w, std=std, generator=generator) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 929a1cb1..ed6fb8b0 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -4,15 +4,19 @@ import torch import torch.nn as nn +from torch.ao.nn.qat import Embedding from torch.distributed import DeviceMesh from olmo_core.config import StrEnum from olmo_core.data.utils import get_cumulative_document_lengths from olmo_core.doc_utils import beta_feature from olmo_core.utils import get_default_device +from ..attention import Attention from ..buffer_cache import BufferCache +from ..feed_forward import FeedForward from ..functional import l2_normalize +from ..layer_norm import RMSNorm from ..lm_head import LMHeadConfig from ..utils import selective_checkpointing_context_fn from .block import TransformerBlock, TransformerBlockConfig @@ -310,6 +314,7 @@ def apply_activation_checkpointing( # TODO: only preserve RNG state if dropout is active preserve_rng_state = True + modules_without_randomness = (Attention, RMSNorm, FeedForward, Embedding) if mode == TransformerActivationCheckpointingMode.selected_modules: from fnmatch import fnmatch @@ -324,7 +329,11 @@ def apply_activation_checkpointing( parent_name = ".".join(name.split(".")[:-1]) parent = self if not parent_name else self.get_submodule(parent_name) - module = ptd_checkpoint_wrapper(module, preserve_rng_state=preserve_rng_state) + + module = ptd_checkpoint_wrapper( + module, + preserve_rng_state=False if isinstance(module, modules_without_randomness) else preserve_rng_state + ) parent.register_module(name.split(".")[-1], module) log.info(f"Wrapped '{name}' for activation checkpointing") else: diff --git a/src/olmo_core/optim/config.py b/src/olmo_core/optim/config.py index f231f608..af5c65a3 100644 --- a/src/olmo_core/optim/config.py +++ b/src/olmo_core/optim/config.py @@ -71,6 +71,7 @@ def build_groups(self, model: nn.Module) -> Union[Iterable[torch.Tensor], List[D all_params: Dict[str, torch.Tensor] = OrderedDict() for n, p in model.named_parameters(): + n = n.replace("._checkpoint_wrapped_module.", ".") all_params[n] = p # Build groups. @@ -87,7 +88,8 @@ def build_groups(self, model: nn.Module) -> Union[Iterable[torch.Tensor], List[D if matches == 0: raise OLMoConfigurationError( - f"optim group {g_idx} override pattern '{pattern}' does not match any parameters" + f"optim group {g_idx} override pattern '{pattern}' does not match any parameters; " + f"valid names are: {', '.join(list(all_params.keys()))}" ) # Put any left-over params into a default group. diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index d618e827..806eb4e9 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -4,7 +4,9 @@ import os import socket import sys +import threading import time +import traceback import uuid import warnings from contextlib import contextmanager @@ -644,3 +646,11 @@ def cuda_sync_debug_mode(debug_mode: Union[int, str]): finally: if current_mode is not None: torch.cuda.set_sync_debug_mode(current_mode) + + +def log_all_threads(): + for thread in threading.enumerate(): + log.info(str(thread)) + + for item in traceback.StackSummary.from_list(traceback.extract_stack(sys._current_frames()[thread.ident])).format(): + log.info(str(item).strip()) diff --git a/src/scripts/reorder_ranks_in_gcp.py b/src/scripts/reorder_ranks_in_gcp.py new file mode 100644 index 00000000..dea8f0d4 --- /dev/null +++ b/src/scripts/reorder_ranks_in_gcp.py @@ -0,0 +1,71 @@ +import sys +from datetime import timedelta + +import requests +import torch.distributed as dist +import argparse + +from urllib3.exceptions import MaxRetryError, NameResolutionError + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("rank", type=int, help="Worker number") + parser.add_argument("world_size", type=int, help="Total number of workers") + parser.add_argument("master_addr", help="Hostname of worker 0") + parser.add_argument("--master_port", type=int, default=29501, help="Port for TCPStore") + parser.add_argument("--debug", action="store_true", help="Enable debug mode (outside of GCP)") + args = parser.parse_args() + + # Create or connect to the store + store = dist.TCPStore( + host_name=args.master_addr, + port=args.master_port, + world_size=args.world_size, + is_master=(args.rank == 0) + ) + + # Get our own host id + if args.debug: + import socket + host_id = f"{socket.gethostname()}_{args.rank}" + else: + try: + response = requests.get( + "http://metadata.google.internal/computeMetadata/v1/instance/attributes/physical_host", + headers={"Metadata-Flavor": "Google"} + ) + assert response.status_code == 200 + host_id = response.text.strip() + except requests.exceptions.ConnectionError as e: + # Unwrap the exception + e = e.args[0] + if not isinstance(e, MaxRetryError): + raise + e = e.reason + if not isinstance(e, NameResolutionError): + raise + # Seems we called this outside of GCP, so we do nothing and just print our original rank. + print(args.rank) + sys.exit(0) + + # Find the index of our host id + store.set(f"node_{args.rank}_hostid", host_id) + store.wait([f"node_{i}_hostid" for i in range(args.world_size)]) + all_host_ids = [store.get(f"node_{i}_hostid").decode("UTF-8") for i in range(args.world_size)] + assert len(set(all_host_ids)) == len(all_host_ids) + assert host_id in all_host_ids + rank0_host_id = all_host_ids[0] + all_host_ids.sort() + # Rank 0 needs to remain rank 0, so we reshuffle around it + rank0_index = all_host_ids.index(rank0_host_id) + all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index] + print(all_host_ids.index(host_id)) + + # Make sure we're all done before exiting + store.set(f"node_{args.rank}_done", host_id) + store.wait([f"node_{i}_done" for i in range(args.world_size)]) + + +if __name__ == "__main__": + main() diff --git a/src/scripts/train/OLMo2-32B.ipynb b/src/scripts/train/OLMo2-32B.ipynb index 3ae3cd0d..3f5c7e0a 100644 --- a/src/scripts/train/OLMo2-32B.ipynb +++ b/src/scripts/train/OLMo2-32B.ipynb @@ -1,1770 +1,1802 @@ { "cells": [ - { - "cell_type": "code", - "id": "initial_id", - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2025-01-05T19:02:26.034046Z", - "start_time": "2025-01-05T19:02:25.324969Z" - } - }, - "source": [ - "import os\n", - "from comet_ml.api import API\n", - "\n", - "comet_api = API(os.environ[\"COMETML_API_KEY\"])\n" - ], - "outputs": [], - "execution_count": 1 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-05T19:02:26.926559Z", - "start_time": "2025-01-05T19:02:26.046253Z" - } - }, - "cell_type": "code", - "source": [ - "exps = {\n", - " \"peteish32\": comet_api.get_experiments(\"ai2\", \"peteish32\", \"peteish32\"),\n", - " \"peteish13\": comet_api.get_experiments(\"ai2\", \"olmo-2-1124-13b\", \"OLMo-2-1124-13B-stage-1\"),\n", - " \"peteish7\": comet_api.get_experiments(\"ai2\", \"olmo-core-7b\", \"peteish7\")\n", - "}\n", - "\n", - "print(repr({k: len(v) for k, v in exps.items()}))" - ], - "id": "2c17abe415dabf07", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'peteish32': 50, 'peteish13': 75, 'peteish7': 13}\n" - ] - } - ], - "execution_count": 2 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-05T19:02:41.764212Z", - "start_time": "2025-01-05T19:02:26.995717Z" - } - }, - "cell_type": "code", - "source": [ - "# print available metrics\n", - "\n", - "for name, es in exps.items():\n", - " metrics = set()\n", - " for exp in es:\n", - " for summary in exp.get_metrics_summary():\n", - " metrics.add(summary[\"name\"])\n", - " metrics = list(metrics)\n", - " metrics.sort()\n", - "\n", - " print(f\"{name}:\")\n", - " for metric in metrics:\n", - " print(\"\\t\", metric)" - ], - "id": "dc7c5e3c92741b89", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "peteish32:\n", - "\t data/sequence length\n", - "\t eval/downstream/arc_challenge (BPB)\n", - "\t eval/downstream/arc_challenge (CE loss)\n", - "\t eval/downstream/arc_challenge (length-normalized accuracy)\n", - "\t eval/downstream/arc_challenge (log soft loss)\n", - "\t eval/downstream/arc_challenge (soft loss)\n", - "\t eval/downstream/arc_challenge_rc_5shot (BPB)\n", - "\t eval/downstream/arc_challenge_rc_5shot (CE loss)\n", - "\t eval/downstream/arc_challenge_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/arc_challenge_rc_5shot (log soft loss)\n", - "\t eval/downstream/arc_challenge_rc_5shot (soft loss)\n", - "\t eval/downstream/arc_challenge_test_mc_5shot (BPB)\n", - "\t eval/downstream/arc_challenge_test_mc_5shot (CE loss)\n", - "\t eval/downstream/arc_challenge_test_mc_5shot (accuracy)\n", - "\t eval/downstream/arc_challenge_test_mc_5shot (log soft loss)\n", - "\t eval/downstream/arc_challenge_test_mc_5shot (soft loss)\n", - "\t eval/downstream/arc_challenge_test_rc_5shot (BPB)\n", - "\t eval/downstream/arc_challenge_test_rc_5shot (CE loss)\n", - "\t eval/downstream/arc_challenge_test_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/arc_challenge_test_rc_5shot (log soft loss)\n", - "\t eval/downstream/arc_challenge_test_rc_5shot (soft loss)\n", - "\t eval/downstream/arc_challenge_val_mc_5shot (BPB)\n", - "\t eval/downstream/arc_challenge_val_mc_5shot (CE loss)\n", - "\t eval/downstream/arc_challenge_val_mc_5shot (accuracy)\n", - "\t eval/downstream/arc_challenge_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/arc_challenge_val_mc_5shot (soft loss)\n", - "\t eval/downstream/arc_challenge_val_rc_5shot (BPB)\n", - "\t eval/downstream/arc_challenge_val_rc_5shot (CE loss)\n", - "\t eval/downstream/arc_challenge_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/arc_challenge_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/arc_challenge_val_rc_5shot (soft loss)\n", - "\t eval/downstream/arc_easy (BPB)\n", - "\t eval/downstream/arc_easy (CE loss)\n", - "\t eval/downstream/arc_easy (accuracy)\n", - "\t eval/downstream/arc_easy (log soft loss)\n", - "\t eval/downstream/arc_easy (soft loss)\n", - "\t eval/downstream/arc_easy_rc_5shot (BPB)\n", - "\t eval/downstream/arc_easy_rc_5shot (CE loss)\n", - "\t eval/downstream/arc_easy_rc_5shot (accuracy)\n", - "\t eval/downstream/arc_easy_rc_5shot (log soft loss)\n", - "\t eval/downstream/arc_easy_rc_5shot (soft loss)\n", - "\t eval/downstream/arc_easy_test_mc_5shot (BPB)\n", - "\t eval/downstream/arc_easy_test_mc_5shot (CE loss)\n", - "\t eval/downstream/arc_easy_test_mc_5shot (accuracy)\n", - "\t eval/downstream/arc_easy_test_mc_5shot (log soft loss)\n", - "\t eval/downstream/arc_easy_test_mc_5shot (soft loss)\n", - "\t eval/downstream/arc_easy_test_rc_5shot (BPB)\n", - "\t eval/downstream/arc_easy_test_rc_5shot (CE loss)\n", - "\t eval/downstream/arc_easy_test_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/arc_easy_test_rc_5shot (log soft loss)\n", - "\t eval/downstream/arc_easy_test_rc_5shot (soft loss)\n", - "\t eval/downstream/arc_easy_val_mc_5shot (BPB)\n", - "\t eval/downstream/arc_easy_val_mc_5shot (CE loss)\n", - "\t eval/downstream/arc_easy_val_mc_5shot (accuracy)\n", - "\t eval/downstream/arc_easy_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/arc_easy_val_mc_5shot (soft loss)\n", - "\t eval/downstream/arc_easy_val_rc_5shot (BPB)\n", - "\t eval/downstream/arc_easy_val_rc_5shot (CE loss)\n", - "\t eval/downstream/arc_easy_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/arc_easy_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/arc_easy_val_rc_5shot (soft loss)\n", - "\t eval/downstream/basic_arithmetic (BPB)\n", - "\t eval/downstream/basic_arithmetic (CE loss)\n", - "\t eval/downstream/basic_arithmetic (accuracy)\n", - "\t eval/downstream/basic_arithmetic (log soft loss)\n", - "\t eval/downstream/basic_arithmetic (soft loss)\n", - "\t eval/downstream/boolq (BPB)\n", - "\t eval/downstream/boolq (CE loss)\n", - "\t eval/downstream/boolq (accuracy)\n", - "\t eval/downstream/boolq (log soft loss)\n", - "\t eval/downstream/boolq (soft loss)\n", - "\t eval/downstream/boolq_val_mc_5shot (BPB)\n", - "\t eval/downstream/boolq_val_mc_5shot (CE loss)\n", - "\t eval/downstream/boolq_val_mc_5shot (accuracy)\n", - "\t eval/downstream/boolq_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/boolq_val_mc_5shot (soft loss)\n", - "\t eval/downstream/boolq_val_rc_5shot (BPB)\n", - "\t eval/downstream/boolq_val_rc_5shot (CE loss)\n", - "\t eval/downstream/boolq_val_rc_5shot (accuracy)\n", - "\t eval/downstream/boolq_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/boolq_val_rc_5shot (soft loss)\n", - "\t eval/downstream/commonsense_qa (BPB)\n", - "\t eval/downstream/commonsense_qa (CE loss)\n", - "\t eval/downstream/commonsense_qa (length-normalized accuracy)\n", - "\t eval/downstream/commonsense_qa (log soft loss)\n", - "\t eval/downstream/commonsense_qa (soft loss)\n", - "\t eval/downstream/copa (BPB)\n", - "\t eval/downstream/copa (CE loss)\n", - "\t eval/downstream/copa (accuracy)\n", - "\t eval/downstream/copa (log soft loss)\n", - "\t eval/downstream/copa (soft loss)\n", - "\t eval/downstream/csqa_rc_5shot (BPB)\n", - "\t eval/downstream/csqa_rc_5shot (CE loss)\n", - "\t eval/downstream/csqa_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/csqa_rc_5shot (log soft loss)\n", - "\t eval/downstream/csqa_rc_5shot (soft loss)\n", - "\t eval/downstream/csqa_val_mc_5shot (BPB)\n", - "\t eval/downstream/csqa_val_mc_5shot (CE loss)\n", - "\t eval/downstream/csqa_val_mc_5shot (accuracy)\n", - "\t eval/downstream/csqa_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/csqa_val_mc_5shot (soft loss)\n", - "\t eval/downstream/csqa_val_rc_5shot (BPB)\n", - "\t eval/downstream/csqa_val_rc_5shot (CE loss)\n", - "\t eval/downstream/csqa_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/csqa_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/csqa_val_rc_5shot (soft loss)\n", - "\t eval/downstream/hellaswag (BPB)\n", - "\t eval/downstream/hellaswag (CE loss)\n", - "\t eval/downstream/hellaswag (length-normalized accuracy)\n", - "\t eval/downstream/hellaswag (log soft loss)\n", - "\t eval/downstream/hellaswag (soft loss)\n", - "\t eval/downstream/hellaswag_rc_5shot (BPB)\n", - "\t eval/downstream/hellaswag_rc_5shot (CE loss)\n", - "\t eval/downstream/hellaswag_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/hellaswag_rc_5shot (log soft loss)\n", - "\t eval/downstream/hellaswag_rc_5shot (soft loss)\n", - "\t eval/downstream/hellaswag_val_mc_5shot (BPB)\n", - "\t eval/downstream/hellaswag_val_mc_5shot (CE loss)\n", - "\t eval/downstream/hellaswag_val_mc_5shot (accuracy)\n", - "\t eval/downstream/hellaswag_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/hellaswag_val_mc_5shot (soft loss)\n", - "\t eval/downstream/hellaswag_val_rc_5shot (BPB)\n", - "\t eval/downstream/hellaswag_val_rc_5shot (CE loss)\n", - "\t eval/downstream/hellaswag_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/hellaswag_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/hellaswag_val_rc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot_test (BPB)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot_test (CE loss)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot_test (log soft loss)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot_test (soft loss)\n", - "\t eval/downstream/mmlu_humanities_val_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_humanities_val_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_humanities_val_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_humanities_val_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_humanities_val_rc_5shot (BPB)\n", - "\t eval/downstream/mmlu_humanities_val_rc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_humanities_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_humanities_val_rc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_other_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_other_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_other_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_other_mc_5shot_test (BPB)\n", - "\t eval/downstream/mmlu_other_mc_5shot_test (CE loss)\n", - "\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_mc_5shot_test (log soft loss)\n", - "\t eval/downstream/mmlu_other_mc_5shot_test (soft loss)\n", - "\t eval/downstream/mmlu_other_val_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_other_val_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_other_val_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_other_val_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_other_val_rc_5shot (BPB)\n", - "\t eval/downstream/mmlu_other_val_rc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_other_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_other_val_rc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (BPB)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (CE loss)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (log soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (BPB)\n", - "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_stem_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_stem_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_stem_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_stem_mc_5shot_test (BPB)\n", - "\t eval/downstream/mmlu_stem_mc_5shot_test (CE loss)\n", - "\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_mc_5shot_test (log soft loss)\n", - "\t eval/downstream/mmlu_stem_mc_5shot_test (soft loss)\n", - "\t eval/downstream/mmlu_stem_val_mc_5shot (BPB)\n", - "\t eval/downstream/mmlu_stem_val_mc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_stem_val_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_stem_val_mc_5shot (soft loss)\n", - "\t eval/downstream/mmlu_stem_val_rc_5shot (BPB)\n", - "\t eval/downstream/mmlu_stem_val_rc_5shot (CE loss)\n", - "\t eval/downstream/mmlu_stem_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/mmlu_stem_val_rc_5shot (soft loss)\n", - "\t eval/downstream/openbook_qa (BPB)\n", - "\t eval/downstream/openbook_qa (CE loss)\n", - "\t eval/downstream/openbook_qa (length-normalized accuracy)\n", - "\t eval/downstream/openbook_qa (log soft loss)\n", - "\t eval/downstream/openbook_qa (soft loss)\n", - "\t eval/downstream/openbookqa_rc_5shot (BPB)\n", - "\t eval/downstream/openbookqa_rc_5shot (CE loss)\n", - "\t eval/downstream/openbookqa_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/openbookqa_rc_5shot (log soft loss)\n", - "\t eval/downstream/openbookqa_rc_5shot (soft loss)\n", - "\t eval/downstream/openbookqa_test_mc_5shot (BPB)\n", - "\t eval/downstream/openbookqa_test_mc_5shot (CE loss)\n", - "\t eval/downstream/openbookqa_test_mc_5shot (accuracy)\n", - "\t eval/downstream/openbookqa_test_mc_5shot (log soft loss)\n", - "\t eval/downstream/openbookqa_test_mc_5shot (soft loss)\n", - "\t eval/downstream/openbookqa_test_rc_5shot (BPB)\n", - "\t eval/downstream/openbookqa_test_rc_5shot (CE loss)\n", - "\t eval/downstream/openbookqa_test_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/openbookqa_test_rc_5shot (log soft loss)\n", - "\t eval/downstream/openbookqa_test_rc_5shot (soft loss)\n", - "\t eval/downstream/openbookqa_val_mc_5shot (BPB)\n", - "\t eval/downstream/openbookqa_val_mc_5shot (CE loss)\n", - "\t eval/downstream/openbookqa_val_mc_5shot (accuracy)\n", - "\t eval/downstream/openbookqa_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/openbookqa_val_mc_5shot (soft loss)\n", - "\t eval/downstream/openbookqa_val_rc_5shot (BPB)\n", - "\t eval/downstream/openbookqa_val_rc_5shot (CE loss)\n", - "\t eval/downstream/openbookqa_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/openbookqa_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/openbookqa_val_rc_5shot (soft loss)\n", - "\t eval/downstream/piqa (BPB)\n", - "\t eval/downstream/piqa (CE loss)\n", - "\t eval/downstream/piqa (length-normalized accuracy)\n", - "\t eval/downstream/piqa (log soft loss)\n", - "\t eval/downstream/piqa (soft loss)\n", - "\t eval/downstream/piqa_rc_5shot (BPB)\n", - "\t eval/downstream/piqa_rc_5shot (CE loss)\n", - "\t eval/downstream/piqa_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/piqa_rc_5shot (log soft loss)\n", - "\t eval/downstream/piqa_rc_5shot (soft loss)\n", - "\t eval/downstream/piqa_val_mc_5shot (BPB)\n", - "\t eval/downstream/piqa_val_mc_5shot (CE loss)\n", - "\t eval/downstream/piqa_val_mc_5shot (accuracy)\n", - "\t eval/downstream/piqa_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/piqa_val_mc_5shot (soft loss)\n", - "\t eval/downstream/piqa_val_rc_5shot (BPB)\n", - "\t eval/downstream/piqa_val_rc_5shot (CE loss)\n", - "\t eval/downstream/piqa_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/piqa_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/piqa_val_rc_5shot (soft loss)\n", - "\t eval/downstream/sciq (BPB)\n", - "\t eval/downstream/sciq (CE loss)\n", - "\t eval/downstream/sciq (accuracy)\n", - "\t eval/downstream/sciq (log soft loss)\n", - "\t eval/downstream/sciq (soft loss)\n", - "\t eval/downstream/social_iqa (BPB)\n", - "\t eval/downstream/social_iqa (CE loss)\n", - "\t eval/downstream/social_iqa (length-normalized accuracy)\n", - "\t eval/downstream/social_iqa (log soft loss)\n", - "\t eval/downstream/social_iqa (soft loss)\n", - "\t eval/downstream/socialiqa_rc_5shot (BPB)\n", - "\t eval/downstream/socialiqa_rc_5shot (CE loss)\n", - "\t eval/downstream/socialiqa_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/socialiqa_rc_5shot (log soft loss)\n", - "\t eval/downstream/socialiqa_rc_5shot (soft loss)\n", - "\t eval/downstream/socialiqa_val_mc_5shot (BPB)\n", - "\t eval/downstream/socialiqa_val_mc_5shot (CE loss)\n", - "\t eval/downstream/socialiqa_val_mc_5shot (accuracy)\n", - "\t eval/downstream/socialiqa_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/socialiqa_val_mc_5shot (soft loss)\n", - "\t eval/downstream/socialiqa_val_rc_5shot (BPB)\n", - "\t eval/downstream/socialiqa_val_rc_5shot (CE loss)\n", - "\t eval/downstream/socialiqa_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/socialiqa_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/socialiqa_val_rc_5shot (soft loss)\n", - "\t eval/downstream/winogrande (BPB)\n", - "\t eval/downstream/winogrande (CE loss)\n", - "\t eval/downstream/winogrande (accuracy)\n", - "\t eval/downstream/winogrande (log soft loss)\n", - "\t eval/downstream/winogrande (soft loss)\n", - "\t eval/downstream/winogrande_rc_5shot (BPB)\n", - "\t eval/downstream/winogrande_rc_5shot (CE loss)\n", - "\t eval/downstream/winogrande_rc_5shot (accuracy)\n", - "\t eval/downstream/winogrande_rc_5shot (log soft loss)\n", - "\t eval/downstream/winogrande_rc_5shot (soft loss)\n", - "\t eval/downstream/winogrande_val_mc_5shot (BPB)\n", - "\t eval/downstream/winogrande_val_mc_5shot (CE loss)\n", - "\t eval/downstream/winogrande_val_mc_5shot (accuracy)\n", - "\t eval/downstream/winogrande_val_mc_5shot (log soft loss)\n", - "\t eval/downstream/winogrande_val_mc_5shot (soft loss)\n", - "\t eval/downstream/winogrande_val_rc_5shot (BPB)\n", - "\t eval/downstream/winogrande_val_rc_5shot (CE loss)\n", - "\t eval/downstream/winogrande_val_rc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/winogrande_val_rc_5shot (log soft loss)\n", - "\t eval/downstream/winogrande_val_rc_5shot (soft loss)\n", - "\t eval/lm/c4_en-validation/CE loss\n", - "\t eval/lm/c4_en-validation/PPL\n", - "\t eval/lm/dolma_books-validation/CE loss\n", - "\t eval/lm/dolma_books-validation/PPL\n", - "\t eval/lm/dolma_common-crawl-validation/CE loss\n", - "\t eval/lm/dolma_common-crawl-validation/PPL\n", - "\t eval/lm/dolma_pes2o-validation/CE loss\n", - "\t eval/lm/dolma_pes2o-validation/PPL\n", - "\t eval/lm/dolma_reddit-validation/CE loss\n", - "\t eval/lm/dolma_reddit-validation/PPL\n", - "\t eval/lm/dolma_stack-validation/CE loss\n", - "\t eval/lm/dolma_stack-validation/PPL\n", - "\t eval/lm/dolma_wiki-validation/CE loss\n", - "\t eval/lm/dolma_wiki-validation/PPL\n", - "\t eval/lm/ice-validation/CE loss\n", - "\t eval/lm/ice-validation/PPL\n", - "\t eval/lm/m2d2_s2orc-validation/CE loss\n", - "\t eval/lm/m2d2_s2orc-validation/PPL\n", - "\t eval/lm/pile-validation/CE loss\n", - "\t eval/lm/pile-validation/PPL\n", - "\t eval/lm/wikitext_103-validation/CE loss\n", - "\t eval/lm/wikitext_103-validation/PPL\n", - "\t optim/LR (group 0)\n", - "\t optim/LR (group 1)\n", - "\t optim/step skipped\n", - "\t optim/total grad norm\n", - "\t sys.compute.overall\n", - "\t sys.compute.utilized\n", - "\t sys.cpu.percent.avg\n", - "\t sys.disk.read_bps\n", - "\t sys.disk.root.percent.used\n", - "\t sys.disk.root.used\n", - "\t sys.disk.write_bps\n", - "\t sys.gpu.0.free_memory\n", - "\t sys.gpu.0.gpu_utilization\n", - "\t sys.gpu.0.memory_utilization\n", - "\t sys.gpu.0.percent.used_memory\n", - "\t sys.gpu.0.power_usage\n", - "\t sys.gpu.0.temperature\n", - "\t sys.gpu.0.total_memory\n", - "\t sys.gpu.0.used_memory\n", - "\t sys.gpu.1.free_memory\n", - "\t sys.gpu.1.gpu_utilization\n", - "\t sys.gpu.1.memory_utilization\n", - "\t sys.gpu.1.percent.used_memory\n", - "\t sys.gpu.1.power_usage\n", - "\t sys.gpu.1.temperature\n", - "\t sys.gpu.1.total_memory\n", - "\t sys.gpu.1.used_memory\n", - "\t sys.gpu.2.free_memory\n", - "\t sys.gpu.2.gpu_utilization\n", - "\t sys.gpu.2.memory_utilization\n", - "\t sys.gpu.2.percent.used_memory\n", - "\t sys.gpu.2.power_usage\n", - "\t sys.gpu.2.temperature\n", - "\t sys.gpu.2.total_memory\n", - "\t sys.gpu.2.used_memory\n", - "\t sys.gpu.3.free_memory\n", - "\t sys.gpu.3.gpu_utilization\n", - "\t sys.gpu.3.memory_utilization\n", - "\t sys.gpu.3.percent.used_memory\n", - "\t sys.gpu.3.power_usage\n", - "\t sys.gpu.3.temperature\n", - "\t sys.gpu.3.total_memory\n", - "\t sys.gpu.3.used_memory\n", - "\t sys.gpu.4.free_memory\n", - "\t sys.gpu.4.gpu_utilization\n", - "\t sys.gpu.4.memory_utilization\n", - "\t sys.gpu.4.percent.used_memory\n", - "\t sys.gpu.4.power_usage\n", - "\t sys.gpu.4.temperature\n", - "\t sys.gpu.4.total_memory\n", - "\t sys.gpu.4.used_memory\n", - "\t sys.gpu.5.free_memory\n", - "\t sys.gpu.5.gpu_utilization\n", - "\t sys.gpu.5.memory_utilization\n", - "\t sys.gpu.5.percent.used_memory\n", - "\t sys.gpu.5.power_usage\n", - "\t sys.gpu.5.temperature\n", - "\t sys.gpu.5.total_memory\n", - "\t sys.gpu.5.used_memory\n", - "\t sys.gpu.6.free_memory\n", - "\t sys.gpu.6.gpu_utilization\n", - "\t sys.gpu.6.memory_utilization\n", - "\t sys.gpu.6.percent.used_memory\n", - "\t sys.gpu.6.power_usage\n", - "\t sys.gpu.6.temperature\n", - "\t sys.gpu.6.total_memory\n", - "\t sys.gpu.6.used_memory\n", - "\t sys.gpu.7.free_memory\n", - "\t sys.gpu.7.gpu_utilization\n", - "\t sys.gpu.7.memory_utilization\n", - "\t sys.gpu.7.percent.used_memory\n", - "\t sys.gpu.7.power_usage\n", - "\t sys.gpu.7.temperature\n", - "\t sys.gpu.7.total_memory\n", - "\t sys.gpu.7.used_memory\n", - "\t sys.load.avg\n", - "\t sys.network.receive_bps\n", - "\t sys.network.send_bps\n", - "\t sys.ram.available\n", - "\t sys.ram.percent.used\n", - "\t sys.ram.total\n", - "\t sys.ram.used\n", - "\t system/GPU active mem (%)\n", - "\t system/GPU active mem (GiB)\n", - "\t system/GPU reserved mem (%)\n", - "\t system/GPU reserved mem (GiB)\n", - "\t throughput/device/BPS\n", - "\t throughput/device/BPS (actual avg)\n", - "\t throughput/device/TPS\n", - "\t throughput/device/TPS (actual avg)\n", - "\t throughput/device/data loading (%)\n", - "\t throughput/device/data loading (s)\n", - "\t throughput/total tokens\n", - "\t train/CE loss\n", - "\t train/PPL\n", - "\t train/Z loss\n", - "peteish13:\n", - "\t eval/downstream/arc_challenge (length-normalized accuracy)\n", - "\t eval/downstream/arc_easy (accuracy)\n", - "\t eval/downstream/basic_arithmetic (accuracy)\n", - "\t eval/downstream/boolq (accuracy)\n", - "\t eval/downstream/commonsense_qa (length-normalized accuracy)\n", - "\t eval/downstream/copa (accuracy)\n", - "\t eval/downstream/hellaswag (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_humanities_var (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_other_var (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_social_sciences_var (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\n", - "\t eval/downstream/mmlu_stem_var (length-normalized accuracy)\n", - "\t eval/downstream/openbook_qa (length-normalized accuracy)\n", - "\t eval/downstream/piqa (length-normalized accuracy)\n", - "\t eval/downstream/sciq (accuracy)\n", - "\t eval/downstream/social_iqa (length-normalized accuracy)\n", - "\t eval/downstream/winogrande (accuracy)\n", - "\t optim/LR (group 0)\n", - "\t optim/LR (group 1)\n", - "\t optim/total grad norm\n", - "\t sys.compute.overall\n", - "\t sys.compute.utilized\n", - "\t sys.cpu.percent.avg\n", - "\t sys.disk.read_bps\n", - "\t sys.disk.root.percent.used\n", - "\t sys.disk.root.used\n", - "\t sys.disk.write_bps\n", - "\t sys.gpu.0.free_memory\n", - "\t sys.gpu.0.gpu_utilization\n", - "\t sys.gpu.0.memory_utilization\n", - "\t sys.gpu.0.percent.used_memory\n", - "\t sys.gpu.0.power_usage\n", - "\t sys.gpu.0.temperature\n", - "\t sys.gpu.0.total_memory\n", - "\t sys.gpu.0.used_memory\n", - "\t sys.gpu.1.free_memory\n", - "\t sys.gpu.1.gpu_utilization\n", - "\t sys.gpu.1.memory_utilization\n", - "\t sys.gpu.1.percent.used_memory\n", - "\t sys.gpu.1.power_usage\n", - "\t sys.gpu.1.temperature\n", - "\t sys.gpu.1.total_memory\n", - "\t sys.gpu.1.used_memory\n", - "\t sys.gpu.2.free_memory\n", - "\t sys.gpu.2.gpu_utilization\n", - "\t sys.gpu.2.memory_utilization\n", - "\t sys.gpu.2.percent.used_memory\n", - "\t sys.gpu.2.power_usage\n", - "\t sys.gpu.2.temperature\n", - "\t sys.gpu.2.total_memory\n", - "\t sys.gpu.2.used_memory\n", - "\t sys.gpu.3.free_memory\n", - "\t sys.gpu.3.gpu_utilization\n", - "\t sys.gpu.3.memory_utilization\n", - "\t sys.gpu.3.percent.used_memory\n", - "\t sys.gpu.3.power_usage\n", - "\t sys.gpu.3.temperature\n", - "\t sys.gpu.3.total_memory\n", - "\t sys.gpu.3.used_memory\n", - "\t sys.load.avg\n", - "\t sys.network.receive_bps\n", - "\t sys.network.send_bps\n", - "\t sys.ram.available\n", - "\t sys.ram.percent.used\n", - "\t sys.ram.total\n", - "\t sys.ram.used\n", - "\t throughput/device/BPS\n", - "\t throughput/device/TPS\n", - "\t train/CE loss\n", - "\t train/PPL\n", - "\t train/Z loss\n", - "peteish7:\n", - "\t optim/LR (group 0)\n", - "\t optim/LR (group 1)\n", - "\t optim/total grad norm\n", - "\t sys.compute.overall\n", - "\t sys.compute.utilized\n", - "\t sys.cpu.percent.avg\n", - "\t sys.disk.read_bps\n", - "\t sys.disk.root.percent.used\n", - "\t sys.disk.root.used\n", - "\t sys.disk.write_bps\n", - "\t sys.load.avg\n", - "\t sys.network.receive_bps\n", - "\t sys.network.send_bps\n", - "\t sys.ram.available\n", - "\t sys.ram.percent.used\n", - "\t sys.ram.total\n", - "\t sys.ram.used\n", - "\t throughput/device/BPS\n", - "\t throughput/device/TPS\n", - "\t train/CE loss\n", - "\t train/PPL\n", - "\t train/Z loss\n" - ] - } - ], - "execution_count": 3 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-05T19:03:16.729657Z", - "start_time": "2025-01-05T19:02:41.776673Z" - } - }, - "cell_type": "code", - "source": [ - "from tqdm.notebook import tqdm\n", - "\n", - "def download_metric(exps, metric_name):\n", - " result = {}\n", - " for exp in tqdm(exps):\n", - " metrics = exp.get_metrics(metric_name)\n", - " for values in metrics:\n", - " result[values['step']] = float(values['metricValue'])\n", - " result = dict(sorted(result.items()))\n", - " return result\n", - "\n", - "loss = {\n", - " name: download_metric(es, \"train/CE loss\")\n", - " for name, es in exps.items()\n", - "}\n", - "\n", - "skipped_steps = {\n", - " name: download_metric(es, \"optim/step skipped\")\n", - " for name, es in exps.items()\n", - "}\n", - "\n", - "speed = {\n", - " name: download_metric(es, \"train/CE loss\")\n", - " for name, es in exps.items()\n", - "}" - ], - "id": "6aa86a5638253061", - "outputs": [ - { - "data": { - "text/plain": [ - " 0%| | 0/50 [00:00 0])" - ], - "id": "277e0e889edb7b16", - "outputs": [ - { - "data": { - "text/plain": [ - "
" - ], - "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-05T11:07:13.567234\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Steps skipped for the 32B: 47\n", - "[848, 1401, 80788, 81072, 84048, 85129, 87386, 92844, 107316, 111491, 113030, 114230, 118668, 121925, 126863, 127493, 128136, 129747, 134843, 136385, 142362, 142815, 144303, 144548, 147139, 147455, 148216, 148703, 150206, 154267, 159678, 159881, 160407, 163682, 167141, 167784, 175621, 187888, 188783, 194308, 200682, 201311, 204820, 205830, 206617, 211141, 212691]\n" - ] - } - ], - "execution_count": 8 - }, { "metadata": {}, - "cell_type": "markdown", - "source": "## Downstream", - "id": "83cbde8bd1160629" - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-05T19:07:12.880276Z", - "start_time": "2025-01-05T19:03:17.042699Z" - } - }, - "cell_type": "code", + "cell_type": "raw", "source": [ - "aggregate_metric_definitions = {\n", - " \"MMLU 5-shot MC\": {\n", - " \"eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\": 0.215,\n", - " \"eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\": 0.335,\n", - " \"eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\": 0.219,\n", - " \"eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\": 0.231\n", - " },\n", - " \"Average of core 12\": {\n", - " \"eval/downstream/arc_challenge (length-normalized accuracy)\": 1 / 12,\n", - " \"eval/downstream/arc_easy (accuracy)\": 1 / 12,\n", - " \"eval/downstream/basic_arithmetic (accuracy)\": 1 / 12,\n", - " \"eval/downstream/boolq (accuracy)\": 1 / 12,\n", - " \"eval/downstream/commonsense_qa (length-normalized accuracy)\": 1 / 12,\n", - " \"eval/downstream/copa (accuracy)\": 1 / 12,\n", - " \"eval/downstream/hellaswag (length-normalized accuracy)\": 1 / 12,\n", - " \"eval/downstream/openbook_qa (length-normalized accuracy)\": 1 / 12,\n", - " \"eval/downstream/piqa (length-normalized accuracy)\": 1 / 12,\n", - " \"eval/downstream/sciq (accuracy)\": 1 / 12,\n", - " \"eval/downstream/social_iqa (length-normalized accuracy)\": 1 / 12,\n", - " \"eval/downstream/winogrande (accuracy)\": 1 / 12,\n", - " },\n", - " \"Hellswag\": {\n", - " \"eval/downstream/hellaswag (length-normalized accuracy)\": 1\n", + "{\n", + " \"cells\": [\n", + " {\n", + " \"cell_type\": \"code\",\n", + " \"id\": \"initial_id\",\n", + " \"metadata\": {\n", + " \"collapsed\": true,\n", + " \"ExecuteTime\": {\n", + " \"end_time\": \"2025-01-05T19:02:26.034046Z\",\n", + " \"start_time\": \"2025-01-05T19:02:25.324969Z\"\n", " }\n", - "}\n", - "\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "%config InlineBackend.figure_format = 'svg'\n", - "import numpy as np\n", - "\n", - "fig, axs = plt.subplots(nrows=len(aggregate_metric_definitions), sharex=True, figsize=(10, len(aggregate_metric_definitions)*3))\n", - "\n", - "for ax, agg_metric_name in zip(axs, aggregate_metric_definitions):\n", - " metric_to_weight = aggregate_metric_definitions[agg_metric_name]\n", - " for run_name, run_exps in exps.items():\n", - " metric_to_values = {}\n", - " for metric in metric_to_weight.keys():\n", - " metric_to_values[metric] = download_metric(run_exps, metric)\n", - "\n", - " all_steps = set.union(*[set(v.keys()) for v in metric_to_values.values()])\n", - " minimal_steps = set.intersection(*[set(v.keys()) for v in metric_to_values.values()])\n", - " if all_steps != minimal_steps:\n", - " print(f\"Missing steps for {run_name} / {agg_metric_name}: {all_steps - minimal_steps}\")\n", - "\n", - " aggregated_values = {}\n", - " for step in minimal_steps:\n", - " value = 0.0\n", - " for metric, weight in metric_to_weight.items():\n", - " value += metric_to_values[metric][step] * weight\n", - " aggregated_values[step] = value\n", - " if len(aggregated_values) == 0:\n", - " continue\n", - "\n", - " print(f\"{run_name} / {agg_metric_name} max: {max(aggregated_values.values())}\")\n", - "\n", - " xs = np.array(list(aggregated_values.keys()))\n", - " ys = np.array(list(aggregated_values.values()))\n", - " order = np.argsort(xs)\n", - " xs = xs[order]\n", - " ys = ys[order]\n", - " xs *= (2048 * 4096)\n", - " ax.plot(xs, ys, linewidth=0.5)\n", - " ax.set_ylabel(agg_metric_name)\n", - "\n", - "plt.xlabel(\"step\")\n", - "plt.show()" - ], - "id": "8b310d9cc68ad856", - "outputs": [ - { - "data": { - "text/plain": [ - " 0%| | 0/50 [00:00" - ], - "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-05T11:07:12.858839\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 6 - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "## Spike Analysis", - "id": "744574cd19bbe369" - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-05T19:07:13.429408Z", - "start_time": "2025-01-05T19:07:12.942046Z" - } - }, - "cell_type": "code", - "source": [ - "window_size = 128\n", - "losses = np.array(list(loss[\"peteish32\"].values()))\n", - "steps = np.array(list(loss[\"peteish32\"].keys()))\n", - "\n", - "from numpy.lib.stride_tricks import sliding_window_view\n", - "windows = sliding_window_view(losses, window_size)\n", - "\n", - "stds = windows.std(axis=1)\n", - "means = windows.mean(axis=1)\n", - "losses = losses[window_size - 1 :]\n", - "steps = steps[window_size - 1 :]\n", - "spike_steps = steps[np.argwhere(losses > means + stds * 6)].flatten()\n", - "print(f\"Steps with spikes: {spike_steps}\")\n", - "\n", - "fig, axes = plt.subplots(\n", - " nrows=len(spike_steps),\n", - " figsize=(7, len(spike_steps)*3),\n", - " sharex=False\n", - ")\n", - "\n", - "for ax, spike in zip(axes, spike_steps):\n", - " for name, values in loss.items():\n", - " xs = np.array(list(values.keys()))\n", - " ys = np.array(list(values.values()))\n", - " ax.plot(xs, ys, linewidth=0.5)\n", - " ax.set_ylim(2.1, 2.5)\n", - " ax.set_xlim(spike-1000, spike+1000)\n", - " plt.yscale('log')\n", - " plt.xlabel(\"step\")\n", - " plt.ylabel(\"loss\")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n" - ], - "id": "6eb5abfb647663a5", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Steps with spikes: [ 29645 38677 49089 54503 66257 73019 144302]\n" - ] - }, - { - "data": { - "text/plain": [ - "
" - ], - "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-05T11:07:13.365572\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" - }, - "metadata": {}, - "output_type": "display_data" - } + " },\n", + " \"source\": [\n", + " \"import os\\n\",\n", + " \"from comet_ml.api import API\\n\",\n", + " \"\\n\",\n", + " \"comet_api = API(os.environ[\\\"COMETML_API_KEY\\\"])\\n\"\n", + " ],\n", + " \"outputs\": [],\n", + " \"execution_count\": 1\n", + " },\n", + " {\n", + " \"metadata\": {\n", + " \"ExecuteTime\": {\n", + " \"end_time\": \"2025-01-05T19:02:26.926559Z\",\n", + " \"start_time\": \"2025-01-05T19:02:26.046253Z\"\n", + " }\n", + " },\n", + " \"cell_type\": \"code\",\n", + " \"source\": [\n", + " \"exps = {\\n\",\n", + " \" \\\"peteish32\\\": comet_api.get_experiments(\\\"ai2\\\", \\\"peteish32\\\", \\\"peteish32\\\"),\\n\",\n", + " \" \\\"peteish13\\\": comet_api.get_experiments(\\\"ai2\\\", \\\"olmo-2-1124-13b\\\", \\\"OLMo-2-1124-13B-stage-1\\\"),\\n\",\n", + " \" \\\"peteish7\\\": comet_api.get_experiments(\\\"ai2\\\", \\\"olmo-core-7b\\\", \\\"peteish7\\\")\\n\",\n", + " \"}\\n\",\n", + " \"\\n\",\n", + " \"print(repr({k: len(v) for k, v in exps.items()}))\"\n", + " ],\n", + " \"id\": \"2c17abe415dabf07\",\n", + " \"outputs\": [\n", + " {\n", + " \"name\": \"stdout\",\n", + " \"output_type\": \"stream\",\n", + " \"text\": [\n", + " \"{'peteish32': 50, 'peteish13': 75, 'peteish7': 13}\\n\"\n", + " ]\n", + " }\n", + " ],\n", + " \"execution_count\": 2\n", + " },\n", + " {\n", + " \"metadata\": {\n", + " \"ExecuteTime\": {\n", + " \"end_time\": \"2025-01-05T19:02:41.764212Z\",\n", + " \"start_time\": \"2025-01-05T19:02:26.995717Z\"\n", + " }\n", + " },\n", + " \"cell_type\": \"code\",\n", + " \"source\": [\n", + " \"# print available metrics\\n\",\n", + " \"\\n\",\n", + " \"for name, es in exps.items():\\n\",\n", + " \" metrics = set()\\n\",\n", + " \" for exp in es:\\n\",\n", + " \" for summary in exp.get_metrics_summary():\\n\",\n", + " \" metrics.add(summary[\\\"name\\\"])\\n\",\n", + " \" metrics = list(metrics)\\n\",\n", + " \" metrics.sort()\\n\",\n", + " \"\\n\",\n", + " \" print(f\\\"{name}:\\\")\\n\",\n", + " \" for metric in metrics:\\n\",\n", + " \" print(\\\"\\\\t\\\", metric)\"\n", + " ],\n", + " \"id\": \"dc7c5e3c92741b89\",\n", + " \"outputs\": [\n", + " {\n", + " \"name\": \"stdout\",\n", + " \"output_type\": \"stream\",\n", + " \"text\": [\n", + " \"peteish32:\\n\",\n", + " \"\\t data/sequence length\\n\",\n", + " \"\\t eval/downstream/arc_challenge (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_challenge (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_challenge (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_test_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_challenge_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_easy (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy (accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_easy_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_rc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_test_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/arc_easy_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/basic_arithmetic (BPB)\\n\",\n", + " \"\\t eval/downstream/basic_arithmetic (CE loss)\\n\",\n", + " \"\\t eval/downstream/basic_arithmetic (accuracy)\\n\",\n", + " \"\\t eval/downstream/basic_arithmetic (log soft loss)\\n\",\n", + " \"\\t eval/downstream/basic_arithmetic (soft loss)\\n\",\n", + " \"\\t eval/downstream/boolq (BPB)\\n\",\n", + " \"\\t eval/downstream/boolq (CE loss)\\n\",\n", + " \"\\t eval/downstream/boolq (accuracy)\\n\",\n", + " \"\\t eval/downstream/boolq (log soft loss)\\n\",\n", + " \"\\t eval/downstream/boolq (soft loss)\\n\",\n", + " \"\\t eval/downstream/boolq_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/boolq_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/boolq_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/boolq_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/boolq_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/boolq_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/boolq_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/boolq_val_rc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/boolq_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/boolq_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/commonsense_qa (BPB)\\n\",\n", + " \"\\t eval/downstream/commonsense_qa (CE loss)\\n\",\n", + " \"\\t eval/downstream/commonsense_qa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/commonsense_qa (log soft loss)\\n\",\n", + " \"\\t eval/downstream/commonsense_qa (soft loss)\\n\",\n", + " \"\\t eval/downstream/copa (BPB)\\n\",\n", + " \"\\t eval/downstream/copa (CE loss)\\n\",\n", + " \"\\t eval/downstream/copa (accuracy)\\n\",\n", + " \"\\t eval/downstream/copa (log soft loss)\\n\",\n", + " \"\\t eval/downstream/copa (soft loss)\\n\",\n", + " \"\\t eval/downstream/csqa_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/csqa_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/csqa_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/csqa_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/csqa_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/csqa_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/csqa_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/csqa_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/csqa_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/csqa_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/csqa_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/csqa_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/csqa_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/csqa_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/csqa_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag (BPB)\\n\",\n", + " \"\\t eval/downstream/hellaswag (CE loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/hellaswag (log soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag (soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/hellaswag_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/hellaswag_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/hellaswag_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot_test (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot_test (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot_test (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot_test (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot_test (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot_test (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot_test (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot_test (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot_test (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot_test (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot_test (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot_test (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot_test (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot_test (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot_test (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot_test (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/openbook_qa (BPB)\\n\",\n", + " \"\\t eval/downstream/openbook_qa (CE loss)\\n\",\n", + " \"\\t eval/downstream/openbook_qa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/openbook_qa (log soft loss)\\n\",\n", + " \"\\t eval/downstream/openbook_qa (soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/openbookqa_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/openbookqa_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_test_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/openbookqa_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa (BPB)\\n\",\n", + " \"\\t eval/downstream/piqa (CE loss)\\n\",\n", + " \"\\t eval/downstream/piqa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/piqa (log soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa (soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/piqa_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/piqa_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/piqa_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/piqa_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/piqa_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/piqa_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/piqa_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/piqa_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/piqa_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/piqa_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/sciq (BPB)\\n\",\n", + " \"\\t eval/downstream/sciq (CE loss)\\n\",\n", + " \"\\t eval/downstream/sciq (accuracy)\\n\",\n", + " \"\\t eval/downstream/sciq (log soft loss)\\n\",\n", + " \"\\t eval/downstream/sciq (soft loss)\\n\",\n", + " \"\\t eval/downstream/social_iqa (BPB)\\n\",\n", + " \"\\t eval/downstream/social_iqa (CE loss)\\n\",\n", + " \"\\t eval/downstream/social_iqa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/social_iqa (log soft loss)\\n\",\n", + " \"\\t eval/downstream/social_iqa (soft loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/socialiqa_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/socialiqa_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/socialiqa_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande (BPB)\\n\",\n", + " \"\\t eval/downstream/winogrande (CE loss)\\n\",\n", + " \"\\t eval/downstream/winogrande (accuracy)\\n\",\n", + " \"\\t eval/downstream/winogrande (log soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande (soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/winogrande_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_rc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/winogrande_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_mc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_mc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_mc_5shot (accuracy)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_mc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_mc_5shot (soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_rc_5shot (BPB)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_rc_5shot (CE loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_rc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_rc_5shot (log soft loss)\\n\",\n", + " \"\\t eval/downstream/winogrande_val_rc_5shot (soft loss)\\n\",\n", + " \"\\t eval/lm/c4_en-validation/CE loss\\n\",\n", + " \"\\t eval/lm/c4_en-validation/PPL\\n\",\n", + " \"\\t eval/lm/dolma_books-validation/CE loss\\n\",\n", + " \"\\t eval/lm/dolma_books-validation/PPL\\n\",\n", + " \"\\t eval/lm/dolma_common-crawl-validation/CE loss\\n\",\n", + " \"\\t eval/lm/dolma_common-crawl-validation/PPL\\n\",\n", + " \"\\t eval/lm/dolma_pes2o-validation/CE loss\\n\",\n", + " \"\\t eval/lm/dolma_pes2o-validation/PPL\\n\",\n", + " \"\\t eval/lm/dolma_reddit-validation/CE loss\\n\",\n", + " \"\\t eval/lm/dolma_reddit-validation/PPL\\n\",\n", + " \"\\t eval/lm/dolma_stack-validation/CE loss\\n\",\n", + " \"\\t eval/lm/dolma_stack-validation/PPL\\n\",\n", + " \"\\t eval/lm/dolma_wiki-validation/CE loss\\n\",\n", + " \"\\t eval/lm/dolma_wiki-validation/PPL\\n\",\n", + " \"\\t eval/lm/ice-validation/CE loss\\n\",\n", + " \"\\t eval/lm/ice-validation/PPL\\n\",\n", + " \"\\t eval/lm/m2d2_s2orc-validation/CE loss\\n\",\n", + " \"\\t eval/lm/m2d2_s2orc-validation/PPL\\n\",\n", + " \"\\t eval/lm/pile-validation/CE loss\\n\",\n", + " \"\\t eval/lm/pile-validation/PPL\\n\",\n", + " \"\\t eval/lm/wikitext_103-validation/CE loss\\n\",\n", + " \"\\t eval/lm/wikitext_103-validation/PPL\\n\",\n", + " \"\\t optim/LR (group 0)\\n\",\n", + " \"\\t optim/LR (group 1)\\n\",\n", + " \"\\t optim/step skipped\\n\",\n", + " \"\\t optim/total grad norm\\n\",\n", + " \"\\t sys.compute.overall\\n\",\n", + " \"\\t sys.compute.utilized\\n\",\n", + " \"\\t sys.cpu.percent.avg\\n\",\n", + " \"\\t sys.disk.read_bps\\n\",\n", + " \"\\t sys.disk.root.percent.used\\n\",\n", + " \"\\t sys.disk.root.used\\n\",\n", + " \"\\t sys.disk.write_bps\\n\",\n", + " \"\\t sys.gpu.0.free_memory\\n\",\n", + " \"\\t sys.gpu.0.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.0.memory_utilization\\n\",\n", + " \"\\t sys.gpu.0.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.0.power_usage\\n\",\n", + " \"\\t sys.gpu.0.temperature\\n\",\n", + " \"\\t sys.gpu.0.total_memory\\n\",\n", + " \"\\t sys.gpu.0.used_memory\\n\",\n", + " \"\\t sys.gpu.1.free_memory\\n\",\n", + " \"\\t sys.gpu.1.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.1.memory_utilization\\n\",\n", + " \"\\t sys.gpu.1.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.1.power_usage\\n\",\n", + " \"\\t sys.gpu.1.temperature\\n\",\n", + " \"\\t sys.gpu.1.total_memory\\n\",\n", + " \"\\t sys.gpu.1.used_memory\\n\",\n", + " \"\\t sys.gpu.2.free_memory\\n\",\n", + " \"\\t sys.gpu.2.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.2.memory_utilization\\n\",\n", + " \"\\t sys.gpu.2.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.2.power_usage\\n\",\n", + " \"\\t sys.gpu.2.temperature\\n\",\n", + " \"\\t sys.gpu.2.total_memory\\n\",\n", + " \"\\t sys.gpu.2.used_memory\\n\",\n", + " \"\\t sys.gpu.3.free_memory\\n\",\n", + " \"\\t sys.gpu.3.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.3.memory_utilization\\n\",\n", + " \"\\t sys.gpu.3.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.3.power_usage\\n\",\n", + " \"\\t sys.gpu.3.temperature\\n\",\n", + " \"\\t sys.gpu.3.total_memory\\n\",\n", + " \"\\t sys.gpu.3.used_memory\\n\",\n", + " \"\\t sys.gpu.4.free_memory\\n\",\n", + " \"\\t sys.gpu.4.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.4.memory_utilization\\n\",\n", + " \"\\t sys.gpu.4.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.4.power_usage\\n\",\n", + " \"\\t sys.gpu.4.temperature\\n\",\n", + " \"\\t sys.gpu.4.total_memory\\n\",\n", + " \"\\t sys.gpu.4.used_memory\\n\",\n", + " \"\\t sys.gpu.5.free_memory\\n\",\n", + " \"\\t sys.gpu.5.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.5.memory_utilization\\n\",\n", + " \"\\t sys.gpu.5.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.5.power_usage\\n\",\n", + " \"\\t sys.gpu.5.temperature\\n\",\n", + " \"\\t sys.gpu.5.total_memory\\n\",\n", + " \"\\t sys.gpu.5.used_memory\\n\",\n", + " \"\\t sys.gpu.6.free_memory\\n\",\n", + " \"\\t sys.gpu.6.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.6.memory_utilization\\n\",\n", + " \"\\t sys.gpu.6.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.6.power_usage\\n\",\n", + " \"\\t sys.gpu.6.temperature\\n\",\n", + " \"\\t sys.gpu.6.total_memory\\n\",\n", + " \"\\t sys.gpu.6.used_memory\\n\",\n", + " \"\\t sys.gpu.7.free_memory\\n\",\n", + " \"\\t sys.gpu.7.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.7.memory_utilization\\n\",\n", + " \"\\t sys.gpu.7.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.7.power_usage\\n\",\n", + " \"\\t sys.gpu.7.temperature\\n\",\n", + " \"\\t sys.gpu.7.total_memory\\n\",\n", + " \"\\t sys.gpu.7.used_memory\\n\",\n", + " \"\\t sys.load.avg\\n\",\n", + " \"\\t sys.network.receive_bps\\n\",\n", + " \"\\t sys.network.send_bps\\n\",\n", + " \"\\t sys.ram.available\\n\",\n", + " \"\\t sys.ram.percent.used\\n\",\n", + " \"\\t sys.ram.total\\n\",\n", + " \"\\t sys.ram.used\\n\",\n", + " \"\\t system/GPU active mem (%)\\n\",\n", + " \"\\t system/GPU active mem (GiB)\\n\",\n", + " \"\\t system/GPU reserved mem (%)\\n\",\n", + " \"\\t system/GPU reserved mem (GiB)\\n\",\n", + " \"\\t throughput/device/BPS\\n\",\n", + " \"\\t throughput/device/BPS (actual avg)\\n\",\n", + " \"\\t throughput/device/TPS\\n\",\n", + " \"\\t throughput/device/TPS (actual avg)\\n\",\n", + " \"\\t throughput/device/data loading (%)\\n\",\n", + " \"\\t throughput/device/data loading (s)\\n\",\n", + " \"\\t throughput/total tokens\\n\",\n", + " \"\\t train/CE loss\\n\",\n", + " \"\\t train/PPL\\n\",\n", + " \"\\t train/Z loss\\n\",\n", + " \"peteish13:\\n\",\n", + " \"\\t eval/downstream/arc_challenge (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/arc_easy (accuracy)\\n\",\n", + " \"\\t eval/downstream/basic_arithmetic (accuracy)\\n\",\n", + " \"\\t eval/downstream/boolq (accuracy)\\n\",\n", + " \"\\t eval/downstream/commonsense_qa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/copa (accuracy)\\n\",\n", + " \"\\t eval/downstream/hellaswag (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_humanities_var (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_other_var (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_social_sciences_var (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/mmlu_stem_var (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/openbook_qa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/piqa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/sciq (accuracy)\\n\",\n", + " \"\\t eval/downstream/social_iqa (length-normalized accuracy)\\n\",\n", + " \"\\t eval/downstream/winogrande (accuracy)\\n\",\n", + " \"\\t optim/LR (group 0)\\n\",\n", + " \"\\t optim/LR (group 1)\\n\",\n", + " \"\\t optim/total grad norm\\n\",\n", + " \"\\t sys.compute.overall\\n\",\n", + " \"\\t sys.compute.utilized\\n\",\n", + " \"\\t sys.cpu.percent.avg\\n\",\n", + " \"\\t sys.disk.read_bps\\n\",\n", + " \"\\t sys.disk.root.percent.used\\n\",\n", + " \"\\t sys.disk.root.used\\n\",\n", + " \"\\t sys.disk.write_bps\\n\",\n", + " \"\\t sys.gpu.0.free_memory\\n\",\n", + " \"\\t sys.gpu.0.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.0.memory_utilization\\n\",\n", + " \"\\t sys.gpu.0.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.0.power_usage\\n\",\n", + " \"\\t sys.gpu.0.temperature\\n\",\n", + " \"\\t sys.gpu.0.total_memory\\n\",\n", + " \"\\t sys.gpu.0.used_memory\\n\",\n", + " \"\\t sys.gpu.1.free_memory\\n\",\n", + " \"\\t sys.gpu.1.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.1.memory_utilization\\n\",\n", + " \"\\t sys.gpu.1.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.1.power_usage\\n\",\n", + " \"\\t sys.gpu.1.temperature\\n\",\n", + " \"\\t sys.gpu.1.total_memory\\n\",\n", + " \"\\t sys.gpu.1.used_memory\\n\",\n", + " \"\\t sys.gpu.2.free_memory\\n\",\n", + " \"\\t sys.gpu.2.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.2.memory_utilization\\n\",\n", + " \"\\t sys.gpu.2.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.2.power_usage\\n\",\n", + " \"\\t sys.gpu.2.temperature\\n\",\n", + " \"\\t sys.gpu.2.total_memory\\n\",\n", + " \"\\t sys.gpu.2.used_memory\\n\",\n", + " \"\\t sys.gpu.3.free_memory\\n\",\n", + " \"\\t sys.gpu.3.gpu_utilization\\n\",\n", + " \"\\t sys.gpu.3.memory_utilization\\n\",\n", + " \"\\t sys.gpu.3.percent.used_memory\\n\",\n", + " \"\\t sys.gpu.3.power_usage\\n\",\n", + " \"\\t sys.gpu.3.temperature\\n\",\n", + " \"\\t sys.gpu.3.total_memory\\n\",\n", + " \"\\t sys.gpu.3.used_memory\\n\",\n", + " \"\\t sys.load.avg\\n\",\n", + " \"\\t sys.network.receive_bps\\n\",\n", + " \"\\t sys.network.send_bps\\n\",\n", + " \"\\t sys.ram.available\\n\",\n", + " \"\\t sys.ram.percent.used\\n\",\n", + " \"\\t sys.ram.total\\n\",\n", + " \"\\t sys.ram.used\\n\",\n", + " \"\\t throughput/device/BPS\\n\",\n", + " \"\\t throughput/device/TPS\\n\",\n", + " \"\\t train/CE loss\\n\",\n", + " \"\\t train/PPL\\n\",\n", + " \"\\t train/Z loss\\n\",\n", + " \"peteish7:\\n\",\n", + " \"\\t optim/LR (group 0)\\n\",\n", + " \"\\t optim/LR (group 1)\\n\",\n", + " \"\\t optim/total grad norm\\n\",\n", + " \"\\t sys.compute.overall\\n\",\n", + " \"\\t sys.compute.utilized\\n\",\n", + " \"\\t sys.cpu.percent.avg\\n\",\n", + " \"\\t sys.disk.read_bps\\n\",\n", + " \"\\t sys.disk.root.percent.used\\n\",\n", + " \"\\t sys.disk.root.used\\n\",\n", + " \"\\t sys.disk.write_bps\\n\",\n", + " \"\\t sys.load.avg\\n\",\n", + " \"\\t sys.network.receive_bps\\n\",\n", + " \"\\t sys.network.send_bps\\n\",\n", + " \"\\t sys.ram.available\\n\",\n", + " \"\\t sys.ram.percent.used\\n\",\n", + " \"\\t sys.ram.total\\n\",\n", + " \"\\t sys.ram.used\\n\",\n", + " \"\\t throughput/device/BPS\\n\",\n", + " \"\\t throughput/device/TPS\\n\",\n", + " \"\\t train/CE loss\\n\",\n", + " \"\\t train/PPL\\n\",\n", + " \"\\t train/Z loss\\n\"\n", + " ]\n", + " }\n", + " ],\n", + " \"execution_count\": 3\n", + " },\n", + " {\n", + " \"metadata\": {\n", + " \"ExecuteTime\": {\n", + " \"end_time\": \"2025-01-05T19:03:16.729657Z\",\n", + " \"start_time\": \"2025-01-05T19:02:41.776673Z\"\n", + " }\n", + " },\n", + " \"cell_type\": \"code\",\n", + " \"source\": [\n", + " \"from tqdm.notebook import tqdm\\n\",\n", + " \"\\n\",\n", + " \"def download_metric(exps, metric_name):\\n\",\n", + " \" result = {}\\n\",\n", + " \" for exp in tqdm(exps):\\n\",\n", + " \" metrics = exp.get_metrics(metric_name)\\n\",\n", + " \" for values in metrics:\\n\",\n", + " \" result[values['step']] = float(values['metricValue'])\\n\",\n", + " \" result = dict(sorted(result.items()))\\n\",\n", + " \" return result\\n\",\n", + " \"\\n\",\n", + " \"loss = {\\n\",\n", + " \" name: download_metric(es, \\\"train/CE loss\\\")\\n\",\n", + " \" for name, es in exps.items()\\n\",\n", + " \"}\\n\",\n", + " \"\\n\",\n", + " \"skipped_steps = {\\n\",\n", + " \" name: download_metric(es, \\\"optim/step skipped\\\")\\n\",\n", + " \" for name, es in exps.items()\\n\",\n", + " \"}\\n\",\n", + " \"\\n\",\n", + " \"speed = {\\n\",\n", + " \" name: download_metric(es, \\\"train/CE loss\\\")\\n\",\n", + " \" for name, es in exps.items()\\n\",\n", + " \"}\"\n", + " ],\n", + " \"id\": \"6aa86a5638253061\",\n", + " \"outputs\": [\n", + " {\n", + " \"data\": {\n", + " \"text/plain\": [\n", + " \" 0%| | 0/50 [00:00 0])\"\n", + " ],\n", + " \"id\": \"277e0e889edb7b16\",\n", + " \"outputs\": [\n", + " {\n", + " \"data\": {\n", + " \"text/plain\": [\n", + " \"
\"\n", + " ],\n", + " \"image/svg+xml\": \"\\n\\n\\n \\n \\n \\n \\n 2025-01-05T11:07:13.567234\\n image/svg+xml\\n \\n \\n Matplotlib v3.9.2, https://matplotlib.org/\\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n\\n\"\n", + " },\n", + " \"metadata\": {},\n", + " \"output_type\": \"display_data\"\n", + " },\n", + " {\n", + " \"name\": \"stdout\",\n", + " \"output_type\": \"stream\",\n", + " \"text\": [\n", + " \"Steps skipped for the 32B: 47\\n\",\n", + " \"[848, 1401, 80788, 81072, 84048, 85129, 87386, 92844, 107316, 111491, 113030, 114230, 118668, 121925, 126863, 127493, 128136, 129747, 134843, 136385, 142362, 142815, 144303, 144548, 147139, 147455, 148216, 148703, 150206, 154267, 159678, 159881, 160407, 163682, 167141, 167784, 175621, 187888, 188783, 194308, 200682, 201311, 204820, 205830, 206617, 211141, 212691]\\n\"\n", + " ]\n", + " }\n", + " ],\n", + " \"execution_count\": 8\n", + " },\n", + " {\n", + " \"metadata\": {},\n", + " \"cell_type\": \"markdown\",\n", + " \"source\": \"## Downstream\",\n", + " \"id\": \"83cbde8bd1160629\"\n", + " },\n", + " {\n", + " \"metadata\": {\n", + " \"ExecuteTime\": {\n", + " \"end_time\": \"2025-01-05T19:07:12.880276Z\",\n", + " \"start_time\": \"2025-01-05T19:03:17.042699Z\"\n", + " }\n", + " },\n", + " \"cell_type\": \"code\",\n", + " \"source\": [\n", + " \"aggregate_metric_definitions = {\\n\",\n", + " \" \\\"MMLU 5-shot MC\\\": {\\n\",\n", + " \" \\\"eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\\\": 0.215,\\n\",\n", + " \" \\\"eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\\\": 0.335,\\n\",\n", + " \" \\\"eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\\\": 0.219,\\n\",\n", + " \" \\\"eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\\\": 0.231\\n\",\n", + " \" },\\n\",\n", + " \" \\\"Average of core 12\\\": {\\n\",\n", + " \" \\\"eval/downstream/arc_challenge (length-normalized accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/arc_easy (accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/basic_arithmetic (accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/boolq (accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/commonsense_qa (length-normalized accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/copa (accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/hellaswag (length-normalized accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/openbook_qa (length-normalized accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/piqa (length-normalized accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/sciq (accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/social_iqa (length-normalized accuracy)\\\": 1 / 12,\\n\",\n", + " \" \\\"eval/downstream/winogrande (accuracy)\\\": 1 / 12,\\n\",\n", + " \" },\\n\",\n", + " \" \\\"Hellswag\\\": {\\n\",\n", + " \" \\\"eval/downstream/hellaswag (length-normalized accuracy)\\\": 1\\n\",\n", + " \" }\\n\",\n", + " \"}\\n\",\n", + " \"\\n\",\n", + " \"import matplotlib.pyplot as plt\\n\",\n", + " \"%matplotlib inline\\n\",\n", + " \"%config InlineBackend.figure_format = 'svg'\\n\",\n", + " \"import numpy as np\\n\",\n", + " \"\\n\",\n", + " \"fig, axs = plt.subplots(nrows=len(aggregate_metric_definitions), sharex=True, figsize=(10, len(aggregate_metric_definitions)*3))\\n\",\n", + " \"\\n\",\n", + " \"for ax, agg_metric_name in zip(axs, aggregate_metric_definitions):\\n\",\n", + " \" metric_to_weight = aggregate_metric_definitions[agg_metric_name]\\n\",\n", + " \" for run_name, run_exps in exps.items():\\n\",\n", + " \" metric_to_values = {}\\n\",\n", + " \" for metric in metric_to_weight.keys():\\n\",\n", + " \" metric_to_values[metric] = download_metric(run_exps, metric)\\n\",\n", + " \"\\n\",\n", + " \" all_steps = set.union(*[set(v.keys()) for v in metric_to_values.values()])\\n\",\n", + " \" minimal_steps = set.intersection(*[set(v.keys()) for v in metric_to_values.values()])\\n\",\n", + " \" if all_steps != minimal_steps:\\n\",\n", + " \" print(f\\\"Missing steps for {run_name} / {agg_metric_name}: {all_steps - minimal_steps}\\\")\\n\",\n", + " \"\\n\",\n", + " \" aggregated_values = {}\\n\",\n", + " \" for step in minimal_steps:\\n\",\n", + " \" value = 0.0\\n\",\n", + " \" for metric, weight in metric_to_weight.items():\\n\",\n", + " \" value += metric_to_values[metric][step] * weight\\n\",\n", + " \" aggregated_values[step] = value\\n\",\n", + " \" if len(aggregated_values) == 0:\\n\",\n", + " \" continue\\n\",\n", + " \"\\n\",\n", + " \" print(f\\\"{run_name} / {agg_metric_name} max: {max(aggregated_values.values())}\\\")\\n\",\n", + " \"\\n\",\n", + " \" xs = np.array(list(aggregated_values.keys()))\\n\",\n", + " \" ys = np.array(list(aggregated_values.values()))\\n\",\n", + " \" order = np.argsort(xs)\\n\",\n", + " \" xs = xs[order]\\n\",\n", + " \" ys = ys[order]\\n\",\n", + " \" xs *= (2048 * 4096)\\n\",\n", + " \" ax.plot(xs, ys, linewidth=0.5)\\n\",\n", + " \" ax.set_ylabel(agg_metric_name)\\n\",\n", + " \"\\n\",\n", + " \"plt.xlabel(\\\"step\\\")\\n\",\n", + " \"plt.show()\"\n", + " ],\n", + " \"id\": \"8b310d9cc68ad856\",\n", + " \"outputs\": [\n", + " {\n", + " \"data\": {\n", + " \"text/plain\": [\n", + " \" 0%| | 0/50 [00:00\"\n", + " ],\n", + " \"image/svg+xml\": \"\\n\\n\\n \\n \\n \\n \\n 2025-01-05T11:07:12.858839\\n image/svg+xml\\n \\n \\n Matplotlib v3.9.2, https://matplotlib.org/\\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n\\n\"\n", + " },\n", + " \"metadata\": {},\n", + " \"output_type\": \"display_data\"\n", + " }\n", + " ],\n", + " \"execution_count\": 6\n", + " },\n", + " {\n", + " \"metadata\": {},\n", + " \"cell_type\": \"markdown\",\n", + " \"source\": \"## Spike Analysis\",\n", + " \"id\": \"744574cd19bbe369\"\n", + " },\n", + " {\n", + " \"metadata\": {\n", + " \"ExecuteTime\": {\n", + " \"end_time\": \"2025-01-05T19:07:13.429408Z\",\n", + " \"start_time\": \"2025-01-05T19:07:12.942046Z\"\n", + " }\n", + " },\n", + " \"cell_type\": \"code\",\n", + " \"source\": [\n", + " \"window_size = 128\\n\",\n", + " \"losses = np.array(list(loss[\\\"peteish32\\\"].values()))\\n\",\n", + " \"steps = np.array(list(loss[\\\"peteish32\\\"].keys()))\\n\",\n", + " \"\\n\",\n", + " \"from numpy.lib.stride_tricks import sliding_window_view\\n\",\n", + " \"windows = sliding_window_view(losses, window_size)\\n\",\n", + " \"\\n\",\n", + " \"stds = windows.std(axis=1)\\n\",\n", + " \"means = windows.mean(axis=1)\\n\",\n", + " \"losses = losses[window_size - 1 :]\\n\",\n", + " \"steps = steps[window_size - 1 :]\\n\",\n", + " \"spike_steps = steps[np.argwhere(losses > means + stds * 6)].flatten()\\n\",\n", + " \"print(f\\\"Steps with spikes: {spike_steps}\\\")\\n\",\n", + " \"\\n\",\n", + " \"fig, axes = plt.subplots(\\n\",\n", + " \" nrows=len(spike_steps),\\n\",\n", + " \" figsize=(7, len(spike_steps)*3),\\n\",\n", + " \" sharex=False\\n\",\n", + " \")\\n\",\n", + " \"\\n\",\n", + " \"for ax, spike in zip(axes, spike_steps):\\n\",\n", + " \" for name, values in loss.items():\\n\",\n", + " \" xs = np.array(list(values.keys()))\\n\",\n", + " \" ys = np.array(list(values.values()))\\n\",\n", + " \" ax.plot(xs, ys, linewidth=0.5)\\n\",\n", + " \" ax.set_ylim(2.1, 2.5)\\n\",\n", + " \" ax.set_xlim(spike-1000, spike+1000)\\n\",\n", + " \" plt.yscale('log')\\n\",\n", + " \" plt.xlabel(\\\"step\\\")\\n\",\n", + " \" plt.ylabel(\\\"loss\\\")\\n\",\n", + " \"\\n\",\n", + " \"plt.tight_layout()\\n\",\n", + " \"plt.show()\\n\"\n", + " ],\n", + " \"id\": \"6eb5abfb647663a5\",\n", + " \"outputs\": [\n", + " {\n", + " \"name\": \"stdout\",\n", + " \"output_type\": \"stream\",\n", + " \"text\": [\n", + " \"Steps with spikes: [ 29645 38677 49089 54503 66257 73019 144302]\\n\"\n", + " ]\n", + " },\n", + " {\n", + " \"data\": {\n", + " \"text/plain\": [\n", + " \"
\"\n", + " ],\n", + " \"image/svg+xml\": \"\\n\\n\\n \\n \\n \\n \\n 2025-01-05T11:07:13.365572\\n image/svg+xml\\n \\n \\n Matplotlib v3.9.2, https://matplotlib.org/\\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n \\n\\n\"\n", + " },\n", + " \"metadata\": {},\n", + " \"output_type\": \"display_data\"\n", + " }\n", + " ],\n", + " \"execution_count\": 7\n", + " }\n", + " ],\n", + " \"metadata\": {\n", + " \"kernelspec\": {\n", + " \"display_name\": \"Python 3\",\n", + " \"language\": \"python\",\n", + " \"name\": \"python3\"\n", + " },\n", + " \"language_info\": {\n", + " \"codemirror_mode\": {\n", + " \"name\": \"ipython\",\n", + " \"version\": 2\n", + " },\n", + " \"file_extension\": \".py\",\n", + " \"mimetype\": \"text/x-python\",\n", + " \"name\": \"python\",\n", + " \"nbconvert_exporter\": \"python\",\n", + " \"pygments_lexer\": \"ipython2\",\n", + " \"version\": \"2.7.6\"\n", + " }\n", + " },\n", + " \"nbformat\": 4,\n", + " \"nbformat_minor\": 5\n", + "}\n" ], - "execution_count": 7 + "id": "7b448fb9dfd7a97d" } ], "metadata": { diff --git a/src/scripts/train/OLMo2-32B.py b/src/scripts/train/OLMo2-32B.py index 7f82210e..b4ba54d7 100644 --- a/src/scripts/train/OLMo2-32B.py +++ b/src/scripts/train/OLMo2-32B.py @@ -7,7 +7,7 @@ from olmo_core.config import DType from olmo_core.distributed.parallel import DataParallelType from olmo_core.float8 import Float8Config -from olmo_core.internal.experiment import CommonComponents, main +from olmo_core.internal.experiment import CommonComponents, main, ExperimentConfig from olmo_core.nn.transformer import ( TransformerActivationCheckpointingConfig, TransformerActivationCheckpointingMode, @@ -26,6 +26,7 @@ log = logging.getLogger(__name__) +NUM_NODES = 16 def build_model_config(common: CommonComponents) -> TransformerConfig: compile = True @@ -35,18 +36,22 @@ def build_model_config(common: CommonComponents) -> TransformerConfig: fused_ops=False, use_flash=not compile, dp_config=TransformerDataParallelConfig( - name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 + name=DataParallelType.fsdp, + param_dtype=DType.bfloat16, + reduce_dtype=DType.float32, + num_replicas=NUM_NODES // 2, ), - # dp_config=TransformerDataParallelConfig( - # name=DataParallelType.hsdp, - # param_dtype=DType.bfloat16, - # reduce_dtype=DType.float32, - # num_replicas=64 // 16, # common.launch.num_nodes // 2, - # ), - # ac_config=TransformerActivationCheckpointingConfig(TransformerActivationCheckpointingMode.full), ac_config=TransformerActivationCheckpointingConfig( mode=TransformerActivationCheckpointingMode.selected_modules, modules=[f"blocks.{i}.feed_forward" for i in range(64)], + #modules=[ + # "embeddings", + # "blocks.*.attention", + # "blocks.*.attention_norm", + # "blocks.*.feed_forward.w1", + # "blocks.*.feed_forward.w3", + # "blocks.*.feed_forward_norm" + #] ), float8_config=Float8Config(compile=compile, enabled=False), ) @@ -62,12 +67,12 @@ def build_optim_config(common: CommonComponents) -> SkipStepAdamWConfig: OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) ], # fused=True, - compile=True, + compile=False, ) def build_trainer_config(common: CommonComponents) -> TrainerConfig: - project_name = "peteish32" + project_name = "peteish32-hybrid" return ( TrainerConfig( save_folder=f"gs://ai2-llm/checkpoints/{project_name}/", @@ -94,7 +99,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: CometCallback( name=common.run_name, workspace="ai2", - project=project_name, + project="peteish32", enabled=True, cancel_check_interval=10, ), @@ -104,7 +109,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: WandBCallback( name=common.run_name, entity="ai2-llm", - project=project_name, + project="peteish32", enabled=False, cancel_check_interval=10, ), @@ -190,10 +195,27 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: ) +def finalize_config(config: ExperimentConfig) -> None: + if config.trainer.load_path is not None and config.trainer.load_path.startswith("gs://"): + source_path = config.trainer.load_path.rstrip("/") + assert len(source_path) > 0 + final_path_component = source_path.rsplit("/", maxsplit=1)[-1] + assert len(final_path_component) > 0 + target_path = f"/data/olmo_core/{final_path_component}" + assert len(target_path) > 0 # just to be extra sure, because we're rm'ing it below + config.launch.setup_steps.extend([ + f"rm -rf {target_path}", + f"mkdir -p {target_path}", + f"gsutil -q -m rsync -r {source_path} {target_path}" + ]) + config.trainer.load_path = target_path + + if __name__ == "__main__": main( - global_batch_size=2048 * 4096, + global_batch_size=2 * 4096 * NUM_NODES * 8, model_config_builder=build_model_config, optim_config_builder=build_optim_config, trainer_config_builder=build_trainer_config, + finalize_config=finalize_config, ) diff --git a/src/scripts/unshard.py b/src/scripts/unshard.py new file mode 100644 index 00000000..d54215c2 --- /dev/null +++ b/src/scripts/unshard.py @@ -0,0 +1,21 @@ +import argparse + +from olmo_core.distributed.checkpoint import unshard_checkpoint +from olmo_core.utils import prepare_cli_environment, LogFilterType + + +def main(): + prepare_cli_environment(LogFilterType.all_ranks) + + parser = argparse.ArgumentParser(description='Unshard a checkpoint') + parser.add_argument('directory', help='directory containing the checkpoint') + parser.add_argument('-o', '--output', help='output directory', default=None) + args = parser.parse_args() + if args.output is None: + args.output = f"{args.directory.rstrip('/')}_unsharded" + + unshard_checkpoint(args.directory, args.output, optim=True) + + +if __name__ == '__main__': + main() \ No newline at end of file