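🐛 Bug
`StreamingDataLoader.state_dict()` fails when the loader wraps a combined dataset and is configured with a custom `collate_fn` and multiple workers (`num_workers=3`).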
To Reproduce
Steps to reproduce the behavior:
Add this unit test to `test_dataloader.py` and run it:
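```python
def test_custom_collate_multiworker():
    # Uses helpers and imports already present in test_dataloader.py:
    # TestCombinedStreamingDataset, TestStatefulDatasetDict, custom_collate_fn, StreamingDataLoader.
    dataset = TestCombinedStreamingDataset(
        [TestStatefulDatasetDict(10, 1), TestStatefulDatasetDict(10, -1)],
        42,
        weights=(0.5, 0.5),
        iterate_over_all=False,
    )
    # Shuffle is unset until the dataloader configures it.
    assert dataset._datasets[0].shuffle is None
    assert dataset._datasets[1].shuffle is None

    dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=3, shuffle=True, collate_fn=custom_collate_fn)
    # shuffle=True on the dataloader is propagated to both wrapped datasets.
    assert dataset._datasets[0].shuffle
    assert dataset._datasets[1].shuffle

    # The assertions below expect custom_collate_fn to return "received" for every batch.
    dataloader_iter = iter(dataloader)
    assert next(dataloader_iter) == "received"
    assert dataloader._num_samples_yielded_combined[0] == [2]
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"

    # Should succeed, but fails with this configuration.
    dataloader.state_dict()
```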
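To make the multi-worker state in this repro concrete: with `batch_size=2` and `num_workers=3`, and assuming PyTorch's default in-order, round-robin batch dispatch with no worker exhausted, the four `next()` calls draw batches from workers 0, 1, 2, 0. The hypothetical helper `samples_per_worker` below (not part of litdata) computes how many samples each worker has contributed at that point, which is roughly the per-worker progress a resumable loader has to record when `state_dict()` is called. It is a sketch under those assumptions, not a description of litdata's internals.

```python
from collections import Counter
from typing import Dict  # typing.Dict keeps this valid on Python 3.8


def samples_per_worker(num_batches: int, num_workers: int, batch_size: int) -> Dict[int, int]:
    """Hypothetical helper: samples contributed by each worker after `num_batches`.

    Assumes batches are returned round-robin (batch i comes from worker
    i % num_workers), as with PyTorch's in-order multi-process loading while
    every worker still has data.
    """
    batches = Counter(i % num_workers for i in range(num_batches))
    return {worker: batches[worker] * batch_size for worker in range(num_workers)}


# Settings from the repro: four batches, three workers, batch_size=2.
print(samples_per_worker(4, 3, 2))  # {0: 4, 1: 2, 2: 2}
```

So at the moment `state_dict()` is called, the workers are at different positions in their streams, and all of them must be captured for a correct resume.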
Expected behavior
The `state_dict()` method should execute without any errors.
Environment
- PyTorch Version (e.g., 1.0): 2.3.1
- OS (e.g., Linux): Mac
- How you installed PyTorch (`conda`, `pip`, source): pip
- Python version: 3.8