diff --git a/tinyllava/train/tinyllava_trainer.py b/tinyllava/train/tinyllava_trainer.py index ca26c59..6164a84 100644 --- a/tinyllava/train/tinyllava_trainer.py +++ b/tinyllava/train/tinyllava_trainer.py @@ -43,12 +43,12 @@ def split_to_even_chunks(indices, lengths, num_chunks): def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): # We need to use torch for the random part as a distributed sampler will set the random seed for torch. assert all(l != 0 for l in lengths), "Should not have zero length." + if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) - assert len(mm_indices) > 0, "Should have at least one multimodal sample." - assert len(lang_indices) > 0, "Should have at least one language sample." - mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] megabatch_size = world_size * batch_size