
IndexError when calling StreamingDataLoader.state_dict() with a custom collate_fn and multiple workers #196

Closed
@esivonxay-cognitiv

Description

🐛 Bug

To Reproduce

Steps to reproduce the behavior:
Add this unit test to the test_dataloader.py file and run it.

def test_custom_collate_multiworker():
    dataset = TestCombinedStreamingDataset(
        [TestStatefulDatasetDict(10, 1), TestStatefulDatasetDict(10, -1)],
        42,
        weights=(0.5, 0.5),
        iterate_over_all=False,
    )
    assert dataset._datasets[0].shuffle is None
    assert dataset._datasets[1].shuffle is None
    # num_workers > 1 combined with a custom collate_fn is what triggers the failure.
    dataloader = StreamingDataLoader(
        dataset, batch_size=2, num_workers=3, shuffle=True, collate_fn=custom_collate_fn
    )
    assert dataset._datasets[0].shuffle
    assert dataset._datasets[1].shuffle
    dataloader_iter = iter(dataloader)
    assert next(dataloader_iter) == "received"
    assert dataloader._num_samples_yielded_combined[0] == [2]
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"
    assert next(dataloader_iter) == "received"

    dataloader.state_dict()  # expected to succeed, but raises an IndexError
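
For context, the test reuses helpers that are already defined earlier in test_dataloader.py (TestStatefulDatasetDict is a small stateful iterable whose samples are dicts). The two pieces most relevant to this report look roughly like the sketch below; this is an approximation for readers of the issue, not the exact definitions from the test file.

# Approximate sketch only -- the real helpers live in test_dataloader.py and
# may differ in detail.
from litdata.streaming import CombinedStreamingDataset  # assumed import path


class TestCombinedStreamingDataset(CombinedStreamingDataset):
    # Relax dataset validation so the lightweight test datasets can be combined.
    def _check_datasets(self, datasets) -> None:
        pass


def custom_collate_fn(samples):
    # Discard the batch and return a sentinel string, so the test can assert
    # that the custom collate function was actually invoked.
    return "received"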

Expected behavior

The state_dict() method should execute without any errors.
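
For reference, the point of state_dict() here is the usual checkpoint/resume flow, which this IndexError currently breaks. A minimal sketch of that flow is below, assuming StreamingDataLoader.load_state_dict for the restore side; the checkpoint path and the second loader are illustrative only.

import torch

# Capture the loader's progress mid-epoch ("checkpoint.pt" is an illustrative path).
state = dataloader.state_dict()  # currently raises IndexError in this setup
torch.save(state, "checkpoint.pt")

# Later: rebuild an identical loader and resume from the saved position.
resumed_loader = StreamingDataLoader(
    dataset, batch_size=2, num_workers=3, shuffle=True, collate_fn=custom_collate_fn
)
resumed_loader.load_state_dict(torch.load("checkpoint.pt"))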

Environment

  • PyTorch Version (e.g., 1.0): 2.3.1
  • OS (e.g., Linux): macOS
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.8

Labels

bug (Something isn't working), help wanted (Extra attention is needed)
