diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 4ea8787a..a7511bcd 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 import argparse
 import datetime
+import functools
 import logging
 import math
 import os
@@ -544,6 +545,28 @@ def train(
         )
 
 
+# This function makes an effort to stick to the default value from the torch
+# library, whatever it may be. That's why we don't just set it to the current
+# (as of the time of writing) default: to cover the unlikely event that torch
+# decides to tweak the default.
+def _get_collective_timeout() -> datetime.timedelta | None:
+    timeout_var = os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS")
+    if timeout_var is None:
+        return None
+
+    try:
+        timeout = int(timeout_var)
+    except ValueError:
+        timeout = -1
+
+    if timeout <= 0:
+        raise ValueError(
+            f"Invalid value for INSTRUCTLAB_NCCL_TIMEOUT_MS: {timeout_var}. Must be a positive integer."
+        )
+
+    return datetime.timedelta(milliseconds=timeout)
+
+
 def main(args):
     if args.distributed_training_framework == "deepspeed" and not FusedAdam:
         raise ImportError(
@@ -571,15 +594,17 @@ def main(args):
     model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
     args.model_type = model_conf.model_type
 
-    # solution discovered from torchtune https://github.com/pytorch/torchtune/issues/2093
-    # gets converted to a timedelta of 1:40:00 if the default is kept
-    nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000"))
     #### distributed init #####
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
     args.local_rank = int(os.environ["LOCAL_RANK"])
-    torch.distributed.init_process_group(
-        "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
-    )
+
+    timeout = _get_collective_timeout()
+    init = functools.partial(torch.distributed.init_process_group, "nccl")
+    if timeout is not None:
+        init(timeout=timeout)
+    else:
+        init()
+
     args.global_rank = torch.distributed.get_rank()
     tensor = torch.ByteTensor([False]).cuda()
     torch.distributed.all_reduce(tensor)
diff --git a/tests/unit/test_main_ds.py b/tests/unit/test_main_ds.py
new file mode 100644
index 00000000..12b35127
--- /dev/null
+++ b/tests/unit/test_main_ds.py
@@ -0,0 +1,39 @@
+# Standard
+from unittest import mock
+import datetime
+
+# Third Party
+import pytest
+
+# First Party
+from instructlab.training import main_ds
+
+
+def test__get_collective_timeout():
+    # Test with default timeout
+    assert main_ds._get_collective_timeout() is None
+
+    # Test with custom timeout
+    timeout = 1234
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": str(timeout)}
+    ):
+        assert main_ds._get_collective_timeout() == datetime.timedelta(
+            milliseconds=timeout
+        )
+
+    # Test with invalid timeout (negative)
+    invalid_timeout = "-100"
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
+    ):
+        with pytest.raises(ValueError):
+            main_ds._get_collective_timeout()
+
+    # Test with invalid timeout (string)
+    invalid_timeout = "invalid"
+    with mock.patch.dict(
+        main_ds.os.environ, {"INSTRUCTLAB_NCCL_TIMEOUT_MS": invalid_timeout}
+    ):
+        with pytest.raises(ValueError):
+            main_ds._get_collective_timeout()
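
For reference, a minimal usage sketch (not part of the patch) of how the new helper is expected to behave in the three cases the unit tests cover; the import path matches the tests above, and the 600000 ms value is only an illustrative assumption.

# Hypothetical sketch; assumes the instructlab-training package is importable.
import datetime
import os

from instructlab.training import main_ds

os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "600000"  # 10 minutes, in milliseconds
assert main_ds._get_collective_timeout() == datetime.timedelta(milliseconds=600000)

os.environ.pop("INSTRUCTLAB_NCCL_TIMEOUT_MS")
assert main_ds._get_collective_timeout() is None  # unset: torch's own default timeout applies

os.environ["INSTRUCTLAB_NCCL_TIMEOUT_MS"] = "0"
try:
    main_ds._get_collective_timeout()  # non-positive or non-integer values raise
except ValueError:
    pass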