diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 37f4fe58..d2459739 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -185,6 +185,7 @@ class TrainingArgs(BaseModel):
 
     mock_data: Optional[bool] = False
     mock_data_len: int = 0
+    mock_num_samples: int = 0
 
     deepspeed_options: DeepSpeedOptions = Field(
         default_factory=lambda: DeepSpeedOptions(
@@ -228,3 +229,9 @@ class TrainingArgs(BaseModel):
         default=False,
         description="Whether to use Liger kernels for training.",
     )
+
+    # TODO(osilkin): Create a better API for this; it should not be merged into the library in this form.
+    use_multipack_v2: bool = Field(
+        default=False,
+        description="Use the MultipackV2 sampler which balances batches based on computational cost. Does not support Padding transformers.",
+    )
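A quick way to confirm the two new config knobs on an installed build is to inspect the model fields directly. This is a minimal sketch, not part of the patch; it assumes pydantic v2, which exposes `model_fields`:

```python
from instructlab.training.config import TrainingArgs

# The fields added in this patch: mock_num_samples (plain default of 0)
# and use_multipack_v2 (a Field with a description string).
for name in ("mock_num_samples", "use_multipack_v2"):
    field = TrainingArgs.model_fields[name]
    print(name, field.default, field.description)
```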
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 1266b0bd..15a81482 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -57,6 +57,9 @@
 from instructlab.training.multipack_sampler import (
     find_packing_max_batch_len_and_grad_accum,
 )
+from instructlab.training.multipack_sampler_v2 import (
+    find_packing_max_batch_len_and_grad_accum as find_packing_max_batch_len_and_grad_accum_v2,
+)
 from instructlab.training.setup_accelerator import setup_accelerator
 from instructlab.training.token_dataset import setup_dataloader, setup_dataset
 from instructlab.training.tokenizer_utils import setup_tokenizer
@@ -379,9 +382,15 @@ def train(
             else None
         )
 
+    # variables for tracking statistics
     global_grad_norm = None
+    stats_momentum = 0.999  # EMA decay; effective averaging window is roughly 1,000 steps
+    avg_throughput = 0.0
+    avg_time_per_step = 0.0
+    num_batches = None
+
     for epoch in range(args.current_epoch, args.num_epochs):
-        if args.sampler in ("multipack"):
+        if args.sampler in ("multipack", "multipack_v2"):
             train_loader.batch_sampler.set_epoch(epoch)
         elif args.sampler in ("distributed"):
             train_loader.sampler.set_epoch(epoch)
@@ -393,6 +402,9 @@ def train(
 
         # blast through the batches in the train loader up to the last step within the epoch.
         for batch in train_loader:
+            if num_batches is None:
+                num_batches = len(train_loader)
+
             if global_step <= args.last_step:
                 # in the case of resuming, last_step > 0
                 global_step += 1
@@ -445,6 +457,28 @@ def train(
             if local_rank == 0:
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
+
+                # exponential moving averages of throughput and step time
+                avg_throughput = (
+                    stats_momentum * avg_throughput
+                    + (1 - stats_momentum) * overall_throughput
+                )
+                avg_time_per_step = (
+                    stats_momentum * avg_time_per_step
+                    + (1 - stats_momentum) * elapsed_time
+                )
+
+                # bias correction so the initial estimates don't tend towards 0
+                corrected_avg_throughput = avg_throughput / (
+                    1 - (stats_momentum**global_step)
+                )
+                corrected_avg_time_per_step = avg_time_per_step / (
+                    1 - (stats_momentum**global_step)
+                )
+
+                # estimate the epoch duration (seconds here; logged below in hours)
+                length_per_epoch = corrected_avg_time_per_step * num_batches
+
                 current_lr = lr_scheduler.get_last_lr()[0]
                 cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
                 cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
@@ -468,6 +502,13 @@ def train(
                         "step": global_step,
                         "rank": torch.distributed.get_rank(),
                         "overall_throughput": overall_throughput,
+                        "avg_overall_throughput": avg_throughput,
+                        "corrected_avg_overall_throughput": corrected_avg_throughput,
+                        "elapsed_time": elapsed_time,
+                        "avg_elapsed_time": avg_time_per_step,
+                        "corrected_avg_elapsed_time": corrected_avg_time_per_step,
+                        "length_per_epoch": length_per_epoch / 3600,
+                        "num_batches": num_batches,
                         "lr": current_lr,
                         "cuda_mem_allocated": cuda_mem_allocated,
                         "cuda_malloc_retries": cuda_malloc_retries,
@@ -580,19 +621,35 @@ def main(args):
         args.data_path,
         mock=args.mock_data,
         mock_len=args.mock_len,
+        mock_num_samples=args.mock_num_samples,
     )
 
     try:
-        packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
-            num_gpus=torch.distributed.get_world_size(),
-            avg_sample_len=dataset.get_lengths().mean(),
-            effective_batch_size=args.effective_batch_size,
-            max_batch_len_per_gpu=args.max_batch_len,
-            is_padding=not (args.use_dolomite or flash_enabled),
-            dataset=dataset,
-            seed=args.seed,
-        )
-        args.sampler = "multipack"
+        if args.use_multipack_v2:
+            packing_max_batch_len, grad_accum = (
+                find_packing_max_batch_len_and_grad_accum_v2(
+                    num_gpus=torch.distributed.get_world_size(),
+                    avg_sample_len=dataset.get_lengths().mean(),
+                    effective_batch_size=args.effective_batch_size,
+                    max_batch_len_per_gpu=args.max_batch_len,
+                    dataset=dataset,
+                    seed=args.seed,
+                )
+            )
+            args.sampler = "multipack_v2"
+        else:
+            packing_max_batch_len, grad_accum = (
+                find_packing_max_batch_len_and_grad_accum(
+                    num_gpus=torch.distributed.get_world_size(),
+                    avg_sample_len=dataset.get_lengths().mean(),
+                    effective_batch_size=args.effective_batch_size,
+                    max_batch_len_per_gpu=args.max_batch_len,
+                    is_padding=not (args.use_dolomite or flash_enabled),
+                    dataset=dataset,
+                    seed=args.seed,
+                )
+            )
+            args.sampler = "multipack"
     except RuntimeError as e:
         if os.environ["LOCAL_RANK"] == "0":
             print(f"\033[38;5;120m{e}\033[0m")
@@ -640,6 +697,11 @@ def main(args):
             seed=args.seed,
         )
 
+    assert (
+        not args.use_multipack_v2
+        or args.sampler == "multipack_v2"
+    ), "multipack_v2 was enabled but is not selected"
+
     if args.local_rank == 0:
         metric_logger.log_sync(
             {
@@ -683,6 +745,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
+    # TODO(osilkin): add a check here for multipack_v2 combined with padding transformers
     check_valid_train_args(train_args)
 
     # switch out generic tmpl for legacy tmpl if requested
@@ -746,8 +809,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
 
     if train_args.mock_data:
         command.append("--mock_data")
-        if train_args.mock_len:
-            command.append(f"--mock_len={train_args.mock_len}")
+        if train_args.mock_data_len:
+            command.append(f"--mock_len={train_args.mock_data_len}")
+        if train_args.mock_num_samples:
+            command.append(f"--mock_num_samples={train_args.mock_num_samples}")
 
     if train_args.use_dolomite:
         command.append("--use_dolomite")
@@ -805,11 +870,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
 
     # FSDP Options
     if train_args.fsdp_options.cpu_offload_params:
-        command.extend(
-            [
-                "--cpu_offload_params_fsdp",
-            ]
-        )
+        command.append("--cpu_offload_params_fsdp")
+
+    if train_args.use_multipack_v2:
+        command += ["--use_multipack_v2"]
 
     # specify the sharding strategy
     command.append(
@@ -933,6 +997,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--mock_data", action="store_true")
     parser.add_argument("--mock_len", type=int, default=2600)
+    parser.add_argument("--mock_num_samples", type=int, default=92_000)
     parser.add_argument(
         "--distributed_training_framework",
         type=str,
@@ -1008,6 +1073,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+    parser.add_argument(
+        "--use_multipack_v2",
+        action="store_true",
+        help="Use the MultipackV2 algorithm for packing batches. This is more optimal but does not support Transformers which require Padding.",
+    )
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
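The flag plumbing added to `run_training` above reduces to a small, pure function; the sketch below condenses it so the mapping from `TrainingArgs` fields to CLI flags is easy to see (names come from this diff, values are placeholders):

```python
def build_extra_flags(mock_data, mock_data_len, mock_num_samples, use_multipack_v2):
    command = []
    if mock_data:
        command.append("--mock_data")
        if mock_data_len:
            command.append(f"--mock_len={mock_data_len}")
        if mock_num_samples:
            command.append(f"--mock_num_samples={mock_num_samples}")
    if use_multipack_v2:
        command.append("--use_multipack_v2")
    return command

print(build_extra_flags(True, 2600, 10_000, True))
# ['--mock_data', '--mock_len=2600', '--mock_num_samples=10000', '--use_multipack_v2']
```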
diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py
index 71d1def2..06ba20b5 100644
--- a/src/instructlab/training/multipack_sampler.py
+++ b/src/instructlab/training/multipack_sampler.py
@@ -413,6 +413,7 @@ def generate_batches(self, set_stats=False):
 
         return batches
 
+    # TODO(osilkin): cache the length here
     def __iter__(self):
         batches = self.generate_batches(set_stats=True)
         return iter(batches)
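One possible shape for the caching the TODO above refers to (a sketch only, not part of this patch): remember the batch count per epoch so `__len__` does not re-run the packing pass.

```python
class _CachedNumBatches:
    """Mixin sketch: assumes the host sampler defines `self.epoch` and `generate_batches()`."""

    _cached_epoch = None
    _cached_num_batches = None

    def num_batches(self):
        # recompute only when the epoch changes or nothing is cached yet
        if self._cached_num_batches is None or self._cached_epoch != self.epoch:
            self._cached_num_batches = len(self.generate_batches())
            self._cached_epoch = self.epoch
        return self._cached_num_batches
```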
diff --git a/src/instructlab/training/multipack_sampler_v2.py b/src/instructlab/training/multipack_sampler_v2.py
new file mode 100644
index 00000000..4a6018c3
--- /dev/null
+++ b/src/instructlab/training/multipack_sampler_v2.py
@@ -0,0 +1,254 @@
+from typing import Optional, List
+
+import torch.distributed as dist
+from torch.utils.data import Sampler
+import torch
+
+import numpy as np
+import numba
+
+
+def find_packing_max_batch_len_and_grad_accum(
+    num_gpus,
+    avg_sample_len,
+    effective_batch_size,
+    max_batch_len_per_gpu,
+    dataset,
+    seed,
+):
+    """
+    Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length.
+
+    This function determines the minimum number of gradient accumulation steps needed to process the
+    effective batch size within the constraints of the maximum batch length per GPU. It starts with
+    the assumption of a single step (no accumulation) and increases the number of steps until the
+    calculated batch length does not exceed the maximum allowed per GPU. The goal is to find the
+    lowest gradient accumulation that allows fitting the batch within GPU limits, ensuring efficient
+    utilization of computational resources.
+
+    Parameters:
+    - num_gpus (int): The number of GPUs over which the batch is distributed.
+    - avg_sample_len (int): The average length of samples in the dataset, used to estimate batch length.
+    - effective_batch_size (int): The total batch size intended to be processed across all GPUs and
+      accumulation steps.
+    - max_batch_len_per_gpu (int): The maximum permissible number of tokens on each GPU to avoid memory overflow.
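+    - dataset: The dataset being packed; its longest sample is checked against the computed pack size.
+    - seed (int): Random seed (not used by this computation; kept for parity with the v1 helper).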
+
+    Returns:
+    - Tuple[int, int]: A tuple where the first element is the maximum batch length that can be achieved
+      without exceeding the per-GPU limit, and the second element is the minimum number of gradient
+      accumulation steps required to maintain the effective batch size.
+    """
+
+    packing_max_batch_len = max_batch_len_per_gpu + 1
+    grad_accum = 0
+    while packing_max_batch_len > max_batch_len_per_gpu:
+        grad_accum += 1
+        samples_per_minibatch = effective_batch_size / grad_accum
+        samples_per_gpu = samples_per_minibatch / num_gpus
+        if int(avg_sample_len * samples_per_gpu) < dataset.get_lengths().max():
+            raise RuntimeError(
+                f"Effective batch size is too low for multipack sampling, max sample length={dataset.get_lengths().max()} and min packing length={int(avg_sample_len * samples_per_gpu)}. "
+                "Switching to naive distributed sampling."
+            )
+
+        packing_max_batch_len = int(avg_sample_len * samples_per_gpu)
+
+    return packing_max_batch_len, grad_accum
+
+
+@numba.njit
+def lpt_check(heap: np.ndarray, A: np.ndarray, c: int, n: int):
+    # LPT (Longest processing time first scheduling)
+    # Time: O(|A| log |A| + |A| log n)
+
+    A = np.sort(A)[::-1]
+    heap.fill(0)
+    for size in A:
+        # Put into smallest element
+        heap[1] += size
+        if heap[1] > c:
+            return False
+
+        # Heapify (Sink)
+        # https://stackoverflow.com/questions/20397674/replacing-element-in-min-heap
+        u = 1
+        while (u << 1) <= n:
+            v = u << 1  # lch
+            rch = (u << 1) | 1
+            if rch <= n and heap[rch] < heap[v]:
+                v = rch
+
+            if heap[u] <= heap[v]:
+                break
+
+            heap[u], heap[v] = heap[v], heap[u]
+            u = v
+
+    return True
+
+
+@numba.njit
+def lpt_with_result(
+    heap: np.ndarray, A: np.ndarray, n: int, start_index: int, rank: int
+):
+    # LPT (Longest processing time first scheduling)
+    # Time: O(|A| log |A| + |A| log n)
+
+    result = []
+
+    indices = np.argsort(A)[::-1]
+    A = A[indices]
+
+    heap.fill(0)
+    heap_id = np.arange(-1, n, dtype=A.dtype)
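+    # heap_id[i] tracks which rank currently owns heap slot i (index 0 is unused, hence the -1 sentinel)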
+    for idx, size in enumerate(A):
+        # Put into smallest element
+        heap[1] += size
+        if heap_id[1] == rank:
+            result.append(start_index + indices[idx])
+
+        # Heapify (Sink)
+        # https://stackoverflow.com/questions/20397674/replacing-element-in-min-heap
+        u = 1
+        while (u << 1) <= n:
+            v = u << 1  # lch
+            rch = (u << 1) | 1
+            if rch <= n and heap[rch] < heap[v]:
+                v = rch
+
+            if heap[u] <= heap[v]:
+                break
+
+            heap[u], heap[v] = heap[v], heap[u]
+            heap_id[u], heap_id[v] = heap_id[v], heap_id[u]
+            u = v
+
+    return result
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, binary search + LPT
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    heap = np.zeros(n + 1, dtype=lengths.dtype)
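+    # 1-indexed binary min-heap of per-rank token loads (index 0 is unused)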
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [l, r)
+        l = 1
+        r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while r - l > 1:
+            m = (l + r) // 2
+            if lpt_check(heap, lengths[start_index : start_index + m], c, n):
+                l = m
+            else:
+                r = m
+
+        # use length l
+        if l < n:
+            break  # Can't allocate each sequence to a single machine
+
+        batch = lpt_with_result(
+            heap, lengths[start_index : start_index + l], n, start_index, rank
+        )
+
+        start_index += l
+        s = lengths_cumsum[start_index - 1]
+
+        # record this rank's batch
+        result.append(batch)
+
+    return result, s, len(result) * c * n
+
+
+class MultipackDistributedBatchSamplerV2(Sampler):
+    """Unpadded length sampling using Multipack V2, for models with quadratic attention complexity.
+    It also tries to evenly distribute the sequences using LPT, so that quadratic load is more balanced.
+
+    Approximate (at most 1.33x ?) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
+
+    Time Complexity: O(n log n log k)
+    n = maximum number of sequences per batch, k = number of nodes
+    """
+
+    def __init__(
+        self,
+        batch_max_length: int,
+        lengths: List[int],
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+    ):
+        # Get rank
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.seed = seed
+
+        self.batch_max_length = batch_max_length
+        self.lengths = lengths
+        assert isinstance(self.lengths, np.ndarray)
+
+        self.epoch = 0
+
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+
+    def set_epoch(self, epoch: int):
+        self.epoch = epoch
+
+    def generate_batches(self, set_stats=False):
+        indices = np.random.Generator(
+            np.random.Philox(seed=self.seed + self.epoch)
+        ).permutation(len(self.lengths))
+
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        batches, total_used, total_slots = allocate(
+            lengths=lengths,
+            lengths_cumsum=lengths_cumsum,
+            rank=self.rank,
+            c=self.batch_max_length,
+            n=self.num_replicas,
+        )
+
+        batches = [indices[batch] for batch in batches]
+
+        # statistics
+        if set_stats:
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots
+
+        return batches
+
+    def __iter__(self):
+        batches = self.generate_batches(set_stats=True)
+        return iter(batches)
+
+    def __len__(self):
+        # NOTE: not cached yet; this regenerates the batches on every call (see the caching TODO in multipack_sampler.py)
+        return self.num_batches()
+
+    def num_batches(self):
+        batches = self.generate_batches()
+        return len(batches)
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
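Because `num_replicas` and `rank` can be passed explicitly, the V2 sampler can be exercised in a single process without initialising torch.distributed. A toy demo with made-up lengths (assumes numpy and numba are installed, as the module already requires):

```python
import numpy as np
from instructlab.training.multipack_sampler_v2 import MultipackDistributedBatchSamplerV2

lengths = np.random.default_rng(0).integers(64, 2048, size=1_000)

for rank in range(2):
    sampler = MultipackDistributedBatchSamplerV2(
        batch_max_length=8192,  # per-rank token budget for one micro-batch
        lengths=lengths,
        num_replicas=2,
        rank=rank,
        seed=7,
    )
    batches = list(iter(sampler))  # each element is an array of dataset indices
    print(rank, len(batches), f"packing efficiency={sampler.efficiency():.3f}")
```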
diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py
index fda9a751..54d7407a 100644
--- a/src/instructlab/training/token_dataset.py
+++ b/src/instructlab/training/token_dataset.py
@@ -10,6 +10,7 @@
 import torch
 
 # First Party
 from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
+from instructlab.training.multipack_sampler_v2 import MultipackDistributedBatchSamplerV2
 from instructlab.training.utils import log_rank_0, make_collate_fn
 
@@ -47,12 +48,12 @@ def get_lengths(self):
 
 
 class MockDataset(Dataset):
-    def __init__(self, data_path, max_seq_len=4600):
+    def __init__(self, data_path, max_seq_len=4600, num_samples=92_000):
         self.input_ids = np.random.randint(
-            0, 10000, size=(92000, max_seq_len), dtype=np.int16
+            0, 10000, size=(num_samples, max_seq_len), dtype=np.int16
         )
         self.labels = np.random.randint(
-            0, 10000, size=(92000, max_seq_len), dtype=np.int16
+            0, 10000, size=(num_samples, max_seq_len), dtype=np.int16
         )
 
     def __len__(self):
@@ -77,10 +78,13 @@ def setup_dataset(
     data_path: str,
     mock: bool = False,
     mock_len: int = 2600,
+    mock_num_samples: int = 92_000,
 ) -> Dataset:
     if mock:
         log_rank_0("Using a mock dataset.")
-        dataset = MockDataset(data_path, max_seq_len=mock_len)
+        dataset = MockDataset(
+            data_path, max_seq_len=mock_len, num_samples=mock_num_samples
+        )
     else:
         dataset = TokenDataset(data_path)
     return dataset
@@ -109,7 +113,7 @@ def setup_dataloader(
 
     lengths = dataset.get_lengths()
     if sampler == "multipack":
-        sampler = MultipackDistributedBatchSampler(
+        sampler_obj = MultipackDistributedBatchSampler(
             batch_max_length=packing_max_batch_len,
             lengths=lengths,
             num_replicas=world_size,
@@ -117,16 +121,26 @@ def setup_dataloader(
             seed=seed,
             padding=not flash_enabled,
         )
-        sampler = {"batch_sampler": sampler}
+        sampler_config = {"batch_sampler": sampler_obj}
+    elif sampler == "multipack_v2":
+        sampler_obj = MultipackDistributedBatchSamplerV2(
+            batch_max_length=packing_max_batch_len,
+            lengths=lengths,
+            num_replicas=world_size,
+            rank=rank,
+            seed=seed,
+        )
+        sampler_config = {"batch_sampler": sampler_obj}
+
     elif sampler == "distributed":
         # Third Party
         from torch.utils.data import DistributedSampler
 
-        sampler = (
+        sampler_obj = (
             DistributedSampler(dataset) if torch.distributed.is_initialized() else None
         )
-        sampler = {
-            "sampler": sampler,
+        sampler_config = {
+            "sampler": sampler_obj,
             "batch_size": samples_per_gpu,
         }
     else:
@@ -134,7 +148,7 @@ def setup_dataloader(
 
     dataloader = DataLoader(
         dataset,
-        **sampler,
+        **sampler_config,
         num_workers=num_workers,
         collate_fn=collate_fn,
     )
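For scale, MockDataset allocates two int16 arrays of shape (num_samples, max_seq_len), so the new `mock_num_samples` knob directly controls host-memory use of the mock path. A rough back-of-the-envelope helper (illustrative numbers only):

```python
def mock_dataset_bytes(num_samples: int, max_seq_len: int) -> int:
    # input_ids + labels, each int16 (2 bytes per token)
    return 2 * num_samples * max_seq_len * 2

print(f"{mock_dataset_bytes(92_000, 2600) / 2**30:.2f} GiB")  # ~0.89 GiB, old hard-coded size
print(f"{mock_dataset_bytes(1_000, 2600) / 2**30:.2f} GiB")   # ~0.01 GiB with mock_num_samples=1000
```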