diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 37f4fe58..d2459739 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -185,6 +185,7 @@ class TrainingArgs(BaseModel):
     mock_data: Optional[bool] = False
     mock_data_len: int = 0
+    mock_num_samples: int = 0

     deepspeed_options: DeepSpeedOptions = Field(
         default_factory=lambda: DeepSpeedOptions(
@@ -228,3 +229,9 @@ class TrainingArgs(BaseModel):
         default=False,
         description="Whether to use Liger kernels for training.",
     )
+
+    # TODO(osilkin): Create a better API for this, should not merge into the library this way
+    use_multipack_v2: bool = Field(
+        default=False,
+        description="Use the MultipackV2 sampler, which balances batches based on computational cost. Does not support transformers that require padding.",
+    )
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 1266b0bd..15a81482 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -57,6 +57,9 @@ from instructlab.training.multipack_sampler import (
     find_packing_max_batch_len_and_grad_accum,
 )
+from instructlab.training.multipack_sampler_v2 import (
+    find_packing_max_batch_len_and_grad_accum as find_packing_max_batch_len_and_grad_accum_v2,
+)
 from instructlab.training.setup_accelerator import setup_accelerator
 from instructlab.training.token_dataset import setup_dataloader, setup_dataset
 from instructlab.training.tokenizer_utils import setup_tokenizer
@@ -379,9 +382,15 @@ def train(
         else None
     )

+    # variables for tracking throughput/timing statistics
     global_grad_norm = None
+    stats_momentum = 0.999  # effectively averages over the last ~1,000 steps
+    avg_throughput = 0.0
+    avg_time_per_step = 0.0
+    num_batches = None
+
     for epoch in range(args.current_epoch, args.num_epochs):
-        if args.sampler in ("multipack"):
+        if args.sampler in ("multipack", "multipack_v2"):
             train_loader.batch_sampler.set_epoch(epoch)
         elif args.sampler in ("distributed"):
             train_loader.sampler.set_epoch(epoch)
@@ -393,6 +402,9 @@ def train(
         # blast through the batches in the train loader up to the last step within the epoch.
         for batch in train_loader:
+            if not num_batches:
+                num_batches = len(train_loader)
+
             if global_step <= args.last_step:
                 # in the case of resuming, last_step > 0
                 global_step += 1

@@ -445,6 +457,28 @@ def train(
             if local_rank == 0:
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
+
+                # moving averages
+                avg_throughput = (
+                    stats_momentum * avg_throughput
+                    + (1 - stats_momentum) * overall_throughput
+                )
+                avg_time_per_step = (
+                    stats_momentum * avg_time_per_step
+                    + (1 - stats_momentum) * elapsed_time
+                )
+
+                # bias correction so the initial values don't tend towards 0
+                corrected_avg_throughput = avg_throughput / (
+                    1 - (stats_momentum**global_step)
+                )
+                corrected_avg_time_per_step = avg_time_per_step / (
+                    1 - (stats_momentum**global_step)
+                )
+
+                # now we can estimate how long a full epoch will take
+                length_per_epoch = corrected_avg_time_per_step * num_batches
+
                 current_lr = lr_scheduler.get_last_lr()[0]
                 cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
                 cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
@@ -468,6 +502,13 @@ def train(
                     "step": global_step,
                     "rank": torch.distributed.get_rank(),
                     "overall_throughput": overall_throughput,
+                    "avg_overall_throughput": avg_throughput,
+                    "corrected_avg_overall_throughput": corrected_avg_throughput,
+                    "elapsed_time": elapsed_time,
+                    "avg_elapsed_time": avg_time_per_step,
+                    "corrected_avg_elapsed_time": corrected_avg_time_per_step,
+                    "length_per_epoch": length_per_epoch / 3600,  # hours
+                    "num_batches": num_batches,
                     "lr": current_lr,
                     "cuda_mem_allocated": cuda_mem_allocated,
                     "cuda_malloc_retries": cuda_malloc_retries,
@@ -580,19 +621,35 @@ def main(args):
         args.data_path,
         mock=args.mock_data,
         mock_len=args.mock_len,
+        mock_num_samples=args.mock_num_samples,
     )

     try:
-        packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
-            num_gpus=torch.distributed.get_world_size(),
-            avg_sample_len=dataset.get_lengths().mean(),
-            effective_batch_size=args.effective_batch_size,
-            max_batch_len_per_gpu=args.max_batch_len,
-            is_padding=not (args.use_dolomite or flash_enabled),
-            dataset=dataset,
-            seed=args.seed,
-        )
-        args.sampler = "multipack"
+        if args.use_multipack_v2:
+            packing_max_batch_len, grad_accum = (
+                find_packing_max_batch_len_and_grad_accum_v2(
+                    num_gpus=torch.distributed.get_world_size(),
+                    avg_sample_len=dataset.get_lengths().mean(),
+                    effective_batch_size=args.effective_batch_size,
+                    max_batch_len_per_gpu=args.max_batch_len,
+                    dataset=dataset,
+                    seed=args.seed,
+                )
+            )
+            args.sampler = "multipack_v2"
+        else:
+            packing_max_batch_len, grad_accum = (
+                find_packing_max_batch_len_and_grad_accum(
+                    num_gpus=torch.distributed.get_world_size(),
+                    avg_sample_len=dataset.get_lengths().mean(),
+                    effective_batch_size=args.effective_batch_size,
+                    max_batch_len_per_gpu=args.max_batch_len,
+                    is_padding=not (args.use_dolomite or flash_enabled),
+                    dataset=dataset,
+                    seed=args.seed,
+                )
+            )
+            args.sampler = "multipack"
     except RuntimeError as e:
         if os.environ["LOCAL_RANK"] == "0":
             print(f"\033[38;5;120m{e}\033[0m")
@@ -640,6 +697,11 @@ def main(args):
         seed=args.seed,
     )

+    assert (
+        not args.use_multipack_v2
+        or args.sampler == "multipack_v2"
+    ), "multipack_v2 was enabled but is not selected"
+
     if args.local_rank == 0:
         metric_logger.log_sync(
             {
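Aside (not part of the patch): the timing bookkeeping added above is an exponential moving average with the usual bias correction, so early readings are not dragged toward the 0.0 initial value. A minimal standalone sketch of the same arithmetic, using made-up numbers (2.5 s/step, 1,200 batches) purely for illustration:

    # EMA + bias correction, mirroring stats_momentum / length_per_epoch above.
    stats_momentum = 0.999  # ~1,000-step effective window

    def update_eta(avg, elapsed, step, num_batches, momentum=stats_momentum):
        avg = momentum * avg + (1 - momentum) * elapsed
        corrected = avg / (1 - momentum**step)   # step starts at 1
        return avg, corrected * num_batches      # estimated seconds per epoch

    avg_time_per_step, epoch_seconds = update_eta(0.0, 2.5, step=1, num_batches=1200)
    print(epoch_seconds / 3600)  # hours, like the "length_per_epoch" metric above

On the first step the correction divides by (1 - 0.999) = 0.001, so the estimate starts at the raw 2.5 s rather than near zero, then converges to the smoothed average.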
""" + # TODO(osilkin): add a check here for multpackv2 and a padding transformers check_valid_train_args(train_args) # switch out generic tmpl for legacy tmpl if requested @@ -746,8 +809,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.mock_data: command.append("--mock_data") - if train_args.mock_len: - command.append(f"--mock_len={train_args.mock_len}") + if train_args.mock_data_len: + command.append(f"--mock_len={train_args.mock_data_len}") + if train_args.mock_num_samples: + command.append(f"--mock_num_samples={train_args.mock_num_samples}") if train_args.use_dolomite: command.append("--use_dolomite") @@ -805,11 +870,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # FSDP Options if train_args.fsdp_options.cpu_offload_params: - command.extend( - [ - "--cpu_offload_params_fsdp", - ] - ) + command.append("--cpu_offload_params_fsdp") + + if train_args.use_multipack_v2: + command += ["--use_multipack_v2"] # specify the sharding strategy command.append( @@ -933,6 +997,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: parser.add_argument("--seed", type=int, default=42) parser.add_argument("--mock_data", action="store_true") parser.add_argument("--mock_len", type=int, default=2600) + parser.add_argument("--mock_num_samples", type=int, default=92_000) parser.add_argument( "--distributed_training_framework", type=str, @@ -1008,6 +1073,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: action="store_true", help="Use Liger kernels for training.", ) + parser.add_argument( + "--use_multipack_v2", + action="store_true", + help="Use the MultipackV2 algorithm for packing batches. This is more optimal but does not support Transformers which require Padding.", + ) args = parser.parse_args() set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index 71d1def2..06ba20b5 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -413,6 +413,7 @@ def generate_batches(self, set_stats=False): return batches + # TODO(osilkin): cache the length here def __iter__(self): batches = self.generate_batches(set_stats=True) return iter(batches) diff --git a/src/instructlab/training/multipack_sampler_v2.py b/src/instructlab/training/multipack_sampler_v2.py new file mode 100644 index 00000000..4a6018c3 --- /dev/null +++ b/src/instructlab/training/multipack_sampler_v2.py @@ -0,0 +1,254 @@ +from typing import Optional, List + +import torch.distributed as dist +from torch.utils.data import Sampler +import torch + +import numpy as np +import numba + + +def find_packing_max_batch_len_and_grad_accum( + num_gpus, + avg_sample_len, + effective_batch_size, + max_batch_len_per_gpu, + dataset, + seed, +): + """ + Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length. + + This function determines the minimum number of gradient accumulation steps needed to process the + effective batch size within the constraints of the maximum batch length per GPU. It starts with + the assumption of a single step (no accumulation) and increases the number of steps until the + calculated batch length does not exceed the maximum allowed per GPU. The goal is to find the + lowest gradient accumulation that allows fitting the batch within GPU limits, ensuring efficient + utilization of computational resources. 
diff --git a/src/instructlab/training/multipack_sampler_v2.py b/src/instructlab/training/multipack_sampler_v2.py
new file mode 100644
index 00000000..4a6018c3
--- /dev/null
+++ b/src/instructlab/training/multipack_sampler_v2.py
@@ -0,0 +1,254 @@
+from typing import Optional, List
+
+import torch.distributed as dist
+from torch.utils.data import Sampler
+import torch
+
+import numpy as np
+import numba
+
+
+def find_packing_max_batch_len_and_grad_accum(
+    num_gpus,
+    avg_sample_len,
+    effective_batch_size,
+    max_batch_len_per_gpu,
+    dataset,
+    seed,
+):
+    """
+    Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length.
+
+    This function determines the minimum number of gradient accumulation steps needed to process the
+    effective batch size within the constraints of the maximum batch length per GPU. It starts with
+    the assumption of a single step (no accumulation) and increases the number of steps until the
+    calculated batch length does not exceed the maximum allowed per GPU. The goal is to find the
+    lowest gradient accumulation that allows fitting the batch within GPU limits, ensuring efficient
+    utilization of computational resources.
+
+    Parameters:
+    - num_gpus (int): The number of GPUs over which the batch is distributed.
+    - avg_sample_len (int): The average length of samples in the dataset, used to estimate batch length.
+    - effective_batch_size (int): The total batch size intended to be processed across all GPUs and
+      accumulation steps.
+    - max_batch_len_per_gpu (int): The maximum permissible number of tokens on each GPU to avoid memory overflow.
+
+    Returns:
+    - Tuple[int, int]: A tuple where the first element is the maximum batch length that can be achieved
+      without exceeding the per-GPU limit, and the second element is the minimum number of gradient
+      accumulation steps required to maintain the effective batch size.
+    """
+
+    packing_max_batch_len = max_batch_len_per_gpu + 1
+    grad_accum = 0
+    while packing_max_batch_len > max_batch_len_per_gpu:
+        grad_accum += 1
+        samples_per_minibatch = effective_batch_size / grad_accum
+        samples_per_gpu = samples_per_minibatch / num_gpus
+        if int(avg_sample_len * samples_per_gpu) < dataset.get_lengths().max():
+            raise RuntimeError(
+                f"Effective batch size is too low for multipack sampling, max sample length={dataset.get_lengths().max()} and min packing length={int(avg_sample_len * samples_per_gpu)}. "
+                "Switching to naive distributed sampling."
+            )
+
+        packing_max_batch_len = int((avg_sample_len) * samples_per_gpu)
+
+    return packing_max_batch_len, grad_accum
+
+
+@numba.njit
+def lpt_check(heap: np.ndarray, A: np.ndarray, c: int, n: int):
+    # LPT (Longest processing time first scheduling)
+    # Time: O(|A| log |A| + |A| log n)
+
+    A = np.sort(A)[::-1]
+    heap.fill(0)
+    for size in A:
+        # Put into smallest element
+        heap[1] += size
+        if heap[1] > c:
+            return False
+
+        # Heapify (Sink)
+        # https://stackoverflow.com/questions/20397674/replacing-element-in-min-heap
+        u = 1
+        while (u << 1) <= n:
+            v = u << 1  # lch
+            rch = (u << 1) | 1
+            if rch <= n and heap[rch] < heap[v]:
+                v = rch
+
+            if heap[u] <= heap[v]:
+                break
+
+            heap[u], heap[v] = heap[v], heap[u]
+            u = v
+
+    return True
+
+
+@numba.njit
+def lpt_with_result(
+    heap: np.ndarray, A: np.ndarray, n: int, start_index: int, rank: int
+):
+    # LPT (Longest processing time first scheduling)
+    # Time: O(|A| log |A| + |A| log n)
+
+    result = []
+
+    indices = np.argsort(A)[::-1]
+    A = A[indices]
+
+    heap.fill(0)
+    heap_id = np.arange(-1, n, dtype=A.dtype)
+    for idx, size in enumerate(A):
+        # Put into smallest element
+        heap[1] += size
+        if heap_id[1] == rank:
+            result.append(start_index + indices[idx])
+
+        # Heapify (Sink)
+        # https://stackoverflow.com/questions/20397674/replacing-element-in-min-heap
+        u = 1
+        while (u << 1) <= n:
+            v = u << 1  # lch
+            rch = (u << 1) | 1
+            if rch <= n and heap[rch] < heap[v]:
+                v = rch
+
+            if heap[u] <= heap[v]:
+                break
+
+            heap[u], heap[v] = heap[v], heap[u]
+            heap_id[u], heap_id[v] = heap_id[v], heap_id[u]
+            u = v
+
+    return result
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, binary search + LPT
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    heap = np.zeros(n + 1, dtype=lengths.dtype)
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [l, r)
+        l = 1
+        r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while r - l > 1:
+            m = (l + r) // 2
+            if lpt_check(heap, lengths[start_index : start_index + m], c, n):
+                l = m
+            else:
+                r = m
+
+        # use length l
+        if l < n:
+            break  # Can't allocate each sequence to a single machine
+
+        batch = lpt_with_result(
+            heap, lengths[start_index : start_index + l], n, start_index, rank
+        )
+
+        start_index += l
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch)
+
+    return result, s, len(result) * c * n
+
+
+class MultipackDistributedBatchSamplerV2(Sampler):
+    """Unpadded length sampling using Multipack V2, for models with quadratic attention complexity.
+    It also tries to evenly distribute the sequences using LPT, so that the quadratic load is more balanced.
+
+    Approximates (at most 1.33x?) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
+
+    Time Complexity: O(n log n log k)
+    n = maximum number of sequences per batch, k = number of nodes
+    """
+
+    def __init__(
+        self,
+        batch_max_length: int,
+        lengths: List[int],
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+    ):
+        # Get rank
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.seed = seed
+
+        self.batch_max_length = batch_max_length
+        self.lengths = lengths
+        assert isinstance(self.lengths, np.ndarray)
+
+        self.epoch = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def generate_batches(self, set_stats=False):
+        indices = np.random.Generator(
+            np.random.Philox(seed=self.seed + self.epoch)
+        ).permutation(len(self.lengths))
+
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=self.rank,
+            c=self.batch_max_length,
+            n=self.num_replicas,
+        )
+
+        batches = [indices[batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches
+
+    def __iter__(self):
+        batches = self.generate_batches(set_stats=True)
+        return iter(batches)
+
+    def __len__(self):
+        # TODO(osilkin): this regenerates the batches on every call; cache the value once computed.
+        return self.num_batches()
+
+    def num_batches(self):
+        batches = self.generate_batches()
+        return len(batches)
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
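Aside (illustrative sketch, not part of the patch): the V2 sampler above yields one list of sample indices per step for the calling rank, so it plugs into a DataLoader as a batch_sampler, which is exactly how setup_dataloader wires it in the next file. The lengths, world size, and rank below are made up; passing num_replicas and rank explicitly avoids needing torch.distributed to be initialized:

    import numpy as np

    from instructlab.training.multipack_sampler_v2 import MultipackDistributedBatchSamplerV2

    lengths = np.random.randint(64, 2048, size=10_000)  # per-sample token counts
    sampler = MultipackDistributedBatchSamplerV2(
        batch_max_length=60_000,  # packing_max_batch_len from the helper above
        lengths=lengths,
        num_replicas=8,
        rank=0,
        seed=42,
    )
    sampler.set_epoch(0)
    # loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
    print(len(sampler), "batches for this rank")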
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index fda9a751..54d7407a 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -10,6 +10,7 @@
 import torch

 # First Party
+from instructlab.training.multipack_sampler_v2 import MultipackDistributedBatchSamplerV2
 from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
 from instructlab.training.utils import log_rank_0, make_collate_fn

@@ -47,12 +48,12 @@ def get_lengths(self):

 class MockDataset(Dataset):
-    def __init__(self, data_path, max_seq_len=4600):
+    def __init__(self, data_path, max_seq_len=4600, num_samples=92_000):
         self.input_ids = np.random.randint(
-            0, 10000, size=(92000, max_seq_len), dtype=np.int16
+            0, 10000, size=(num_samples, max_seq_len), dtype=np.int16
         )
         self.labels = np.random.randint(
-            0, 10000, size=(92000, max_seq_len), dtype=np.int16
+            0, 10000, size=(num_samples, max_seq_len), dtype=np.int16
         )

     def __len__(self):
@@ -77,10 +78,13 @@ def setup_dataset(
     data_path: str,
     mock: bool = False,
     mock_len: int = 2600,
+    mock_num_samples: int = 92_000,
 ) -> Dataset:
     if mock:
         log_rank_0("Using a mock dataset.")
-        dataset = MockDataset(data_path, max_seq_len=mock_len)
+        dataset = MockDataset(
+            data_path, max_seq_len=mock_len, num_samples=mock_num_samples
+        )
     else:
         dataset = TokenDataset(data_path)
     return dataset
@@ -109,7 +113,7 @@ def setup_dataloader(
     lengths = dataset.get_lengths()

     if sampler == "multipack":
-        sampler = MultipackDistributedBatchSampler(
+        sampler_obj = MultipackDistributedBatchSampler(
             batch_max_length=packing_max_batch_len,
             lengths=lengths,
             num_replicas=world_size,
@@ -117,16 +121,26 @@ def setup_dataloader(
             rank=rank,
             seed=seed,
             padding=not flash_enabled,
         )
-        sampler = {"batch_sampler": sampler}
+        sampler_config = {"batch_sampler": sampler_obj}
+    elif sampler == "multipack_v2":
+        sampler_obj = MultipackDistributedBatchSamplerV2(
+            batch_max_length=packing_max_batch_len,
+            lengths=lengths,
+            num_replicas=world_size,
+            rank=rank,
+            seed=seed,
+        )
+        sampler_config = {"batch_sampler": sampler_obj}
+
     elif sampler == "distributed":
         # Third Party
         from torch.utils.data import DistributedSampler

-        sampler = (
+        sampler_obj = (
             DistributedSampler(dataset) if torch.distributed.is_initialized() else None
         )
-        sampler = {
-            "sampler": sampler,
+        sampler_config = {
+            "sampler": sampler_obj,
             "batch_size": samples_per_gpu,
         }
     else:
@@ -134,7 +148,7 @@ def setup_dataloader(
     dataloader = DataLoader(
         dataset,
-        **sampler,
+        **sampler_config,
         num_workers=num_workers,
         collate_fn=collate_fn,
     )
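Aside (sketch only, not part of the patch): the knobs this change exposes on TrainingArgs, written as the overrides a caller would pass on top of the usual required fields (model path, data path, output dir, batch sizes, ...), which are unchanged and omitted here. Field names come from config.py in this diff; the values and the required_fields placeholder are hypothetical:

    new_options = {
        "mock_data": True,
        "mock_data_len": 2048,
        "mock_num_samples": 10_000,  # new field: rows in the synthetic mock dataset
        "use_multipack_v2": True,    # new field: opt into the MultipackV2 batch sampler
    }
    # train_args = TrainingArgs(**required_fields, **new_options)

run_training then forwards these to the torchrun child process as --mock_num_samples and --use_multipack_v2, per the command-building hunks above.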