Skip to content

Commit 31b54a6

Browse files
authored
🌊 Add error for iterable datasets in GRPOTrainer (#3216)
1 parent 17e33cd commit 31b54a6

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

trl/trainer/grpo_trainer.py

+13
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,19 @@ def data_collator(features): # No data collation is needed in GRPO
410410
self.use_vllm = args.use_vllm
411411
self.use_liger_loss = args.use_liger_loss
412412

413+
# Datasets
414+
if (
415+
isinstance(train_dataset, IterableDataset)
416+
or isinstance(eval_dataset, IterableDataset)
417+
or (
418+
isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
419+
)
420+
):
421+
# See https://github.com/huggingface/trl/issues/3213
422+
raise NotImplementedError(
423+
"Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
424+
)
425+
413426
# Multi-step
414427
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
415428
self.epsilon_low = args.epsilon

0 commit comments

Comments
 (0)