feat: use a shared queue across workers for data processing #559

Closed
Changes from 12 commits
1 change: 1 addition & 0 deletions Makefile
@@ -34,6 +34,7 @@ clean:
rm -rf ./dist

install-dependencies:
pip install -U lightning-sdk
pip install -r requirements.txt
pip install -r requirements/test.txt
pip install -r requirements/docs.txt
197 changes: 104 additions & 93 deletions src/litdata/processing/data_processor.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/litdata/processing/functions.py
@@ -125,7 +125,7 @@ def __init__(
def prepare_structure(self, _: Optional[str]) -> Any:
return self._inputs

def prepare_item(self, item_metadata: Any, output_dir: str, is_last: bool) -> None:
def prepare_item(self, item_metadata: Any, output_dir: str, is_last: bool = False) -> None:
if self._contains_device and self._device is None:
self._find_device()

3 changes: 3 additions & 0 deletions src/litdata/streaming/dataloader.py
@@ -248,6 +248,7 @@ def _next_data(self) -> Any:
data = None
while data is None:
data = super()._next_data()
print(f"dataloader iter patch => data: {data}")
return data
except StopIteration as e:
raise e
@@ -491,6 +492,7 @@ def _try_put_index(self) -> None:
self._task_info[self._send_idx] = (worker_queue_idx,)
self._tasks_outstanding += 1
self._send_idx += 1
print(f"{self._send_idx=}, {worker_queue_idx=}, {index=}, {worker_queue_idx=} {self._tasks_outstanding=}")
else:
super()._try_put_index()

@@ -720,6 +722,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:

# Used to restart on the next DataLoader worker from the previous run.
self._latest_worker_idx = obj["latest_worker_idx"] + 1
print(f"restarting dataloader from statedict with {self._latest_worker_idx=}")
self._worker_idx_iter = iter(self._worker_idx)
for _ in range(self._latest_worker_idx):
next(self._worker_idx_iter)
1 change: 0 additions & 1 deletion src/litdata/streaming/dataset.py
@@ -475,7 +475,6 @@ def __next__(self) -> Any:
self.has_triggered_download = True
self.global_index += 1 # total number of samples processed by the current worker
self.consumed_sample_count_in_curr_chunk += 1 # number of samples processed in the current chunk

return data

def state_dict(self, num_samples_yielded: int, num_workers: int, batch_size: int) -> Dict[str, Any]:
1 change: 1 addition & 0 deletions src/litdata/streaming/writer.py
@@ -172,6 +172,7 @@ def serialize(self, items: Any) -> Tuple[bytes, Optional[int]]:

worker_rank = get_worker_rank()
if worker_rank is not None:
print(flush=True) # to prevent truncated printing when using concurrent threads/processes
print(f"Rank {worker_rank} inferred the following `{data_format}` data format.")
self._data_format = data_format
self._data_spec = data_spec
267 changes: 122 additions & 145 deletions tests/processing/test_data_processor.py

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions tests/processing/test_functions.py
@@ -96,7 +96,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5)]

with pytest.raises(RuntimeError, match="HINT: If you want to append/overwrite to the existing dataset"):
optimize(
@@ -129,7 +129,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 5
assert ds[:] == [(i, i**2) for i in range(5, 10)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5, 10)] # each worker can pick items in any order

optimize(
fn=compress,
@@ -143,7 +143,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 10
assert ds[:] == [(i, i**2) for i in range(5, 15)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5, 15)]

optimize(
fn=compress,
@@ -157,7 +157,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 15
assert ds[:] == [(i, i**2) for i in range(5, 20)]
assert sorted(ds[:]) == [(i, i**2) for i in range(5, 20)]

with pytest.raises(Exception, match="The config isn't consistent between chunks"):
optimize(
@@ -181,7 +181,7 @@ def test_optimize_append_overwrite(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 5
assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]
assert sorted(ds[:]) == [(i, i**2, i**3) for i in range(0, 5)]


@pytest.mark.skipif(sys.platform == "win32", reason="too slow")
@@ -216,7 +216,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 4
assert ds[:] == [(i, i**2) for i in range(4)]
assert sorted(ds[:]) == [(i, i**2) for i in range(4)] # for multiple workers, the order of items is not guaranteed
# checkpoints should be deleted
assert not os.path.exists(os.path.join(output_dir, ".checkpoints"))

@@ -257,7 +257,7 @@ def test_optimize_checkpoint_in_none_and_append_mode(tmpdir):
ds = StreamingDataset(output_dir)

assert len(ds) == 8
assert ds[:] == [(i, i**2) for i in range(8)]
assert sorted(ds[:]) == [(i, i**2) for i in range(8)]
# checkpoints should be deleted
assert not os.path.exists(os.path.join(output_dir, ".checkpoints"))

81 changes: 58 additions & 23 deletions tests/streaming/test_dataset.py
@@ -758,10 +758,12 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
L = len(dataset)
assert L == 20

returned_data = []
for i in range(L):
sequence = dataset[i]
assert sequence[0].item() == i * block_size
assert sequence[-1].item() == (i + 1) * block_size - 1
returned_data.append((sequence[0].item(), sequence[-1].item()))
expected_data = [(i * block_size, (i + 1) * block_size - 1) for i in range(L)]
assert sorted(returned_data) == expected_data

monkeypatch.setenv("WORLD_SIZE", "2")
monkeypatch.setenv("GLOBAL_RANK", "0")
@@ -780,11 +782,13 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
# one worker will yield 2 batches, the other will yield 3 => len(dataloader) = 5
assert len(dataloader) == 5

expected = [[0, 10], [60, 70], [20, 30], [80, 90], [40, 50]]
returned = []
# we can't foresee exactly which items node 0 and node 1 will yield,
# but they will be different and together should completely describe the dataset
# expected = [[0, 10], [60, 70], [20, 30], [80, 90], [40, 50]]
rank_0_returned = []
for batch in dataloader:
returned.append(batch[:, 0].tolist())
assert returned == expected
rank_0_returned.append(batch[:, 0].tolist())
assert len(rank_0_returned) == 5

monkeypatch.setenv("WORLD_SIZE", "2")
monkeypatch.setenv("GLOBAL_RANK", "1")
@@ -795,11 +799,15 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk

assert len(dataloader) == 5

expected = [[100, 110], [160, 170], [120, 130], [180, 190], [140, 150]]
returned = []
rank_1_returned = []
for batch in dataloader:
returned.append(batch[:, 0].tolist())
assert returned == expected
rank_1_returned.append(batch[:, 0].tolist())
assert len(rank_1_returned) == 5

returned_items = sorted(rank_0_returned + rank_1_returned)
assert len(returned_items) == 10
print(f"{returned_items=}")
assert returned_items == [[i, i + 10] for i in range(0, 200, 20)]
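
# A small illustration of the combined-coverage check above; this is only a sketch,
# assuming block_size=10 and batch_size=2 (which is what the expected values imply).
# Each rank sees 10 of the 20 sequences -> 5 batches of 2 sequences per rank, and the
# union of both ranks' batches must cover every pair of consecutive sequences exactly once.
block_size, batch_size, num_sequences = 10, 2, 20
expected_combined = [
    [seq * block_size, (seq + 1) * block_size]  # first token of each sequence in a batch
    for seq in range(0, num_sequences, batch_size)
]
assert expected_combined == [[i, i + 10] for i in range(0, 200, 20)]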


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
@@ -978,7 +986,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir, shuffle=False):

@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.timeout(60)
@pytest.mark.timeout(120)
@pytest.mark.parametrize("shuffle", [True, False])
def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
"""Tests resuming from a chunk past the first chunk, when subsequent chunks don't have the same size."""
@@ -989,6 +997,16 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", optimize_data_cache_dir)
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", optimize_cache_dir)

# 8*10*10 = 800 items will be stored in chunks of max_size = 190 with 4 workers.
# 4 full chunks of 190 items hold 760 items,
# leaving 800 - 760 = 40 leftover items.
# Those leftover items can be split across the workers in any way, so we can't predict
# the exact chunk count, but we can bound it with a min-max range:
# min => all leftover items flushed by a single worker in one extra chunk = 4 + 1 = 5 chunks
# max => the leftover items spread across all 4 workers = up to 4 extra small chunks
# so, max chunk count = 4 + 4 = 8 chunks
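# A minimal sanity-check of the 5-8 chunk bound derived above (an illustrative sketch
# only, not part of the test; the variable names below are hypothetical):
total_items = 8 * 10 * 10                                    # 800 items in total
chunk_capacity = 190                                         # max items per chunk, as passed to optimize() below
full_chunks, leftover = divmod(total_items, chunk_capacity)  # 4 full chunks, 40 leftover items
min_chunk_count = full_chunks + 1                            # all leftover items flushed by a single worker
max_chunk_count = full_chunks + 4                            # leftover items spread across all 4 workers
assert (min_chunk_count, max_chunk_count) == (5, 8)          # matches the listdir bound below (+1 for index.json)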

optimize(
fn=_simple_preprocess,
inputs=list(range(8)),
@@ -998,17 +1016,30 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
num_uploaders=1,
item_loader=TokensLoader(block_size=10),
)
assert set(os.listdir(data_dir)) == {
"chunk-0-0.bin",
"chunk-0-1.bin",
"chunk-1-0.bin",
"chunk-1-1.bin",
"chunk-2-0.bin",
"chunk-2-1.bin",
"chunk-3-0.bin",
"chunk-3-1.bin",
"index.json",
}
# print(f"{os.listdir(data_dir)=}")
# # print items in head of each
# for file_name in os.listdir(data_dir):
# file_path = os.path.join(data_dir, file_name)

# with open(file_path, "rb") as f:
# head_bytes = f.read(4) # read first 4 bytes
# if len(head_bytes) < 4:
# print(f"{file_name}: File too short")
# continue
# val = np.frombuffer(head_bytes, dtype=np.int32)[0]
# print(f"{file_name}: {val}")
assert 6 <= len(os.listdir(data_dir)) <= 9 # +1 for index.json file

# check if the dataloader contains the complete dataset
os.mkdir(s3_cache_dir)
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle)

fetched_dataset = []
for i, batch in enumerate(train_dataloader):
fetched_dataset.extend(batch)
assert len(fetched_dataset) == 80

shutil.rmtree(s3_cache_dir)

os.mkdir(s3_cache_dir)
train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir, shuffle=shuffle)
@@ -1029,8 +1060,12 @@ def test_dataset_resume_on_future_chunks(shuffle, tmpdir, monkeypatch):
assert dataloader_state is not None
assert batch_to_resume_from is not None
train_dataloader.load_state_dict(dataloader_state)
print(f"{dataloader_state=}")
print(f"{batch_to_resume_from=}")
next_batch_data = next(iter(train_dataloader))
print(f"{next_batch_data=}")
# The next batch after resuming must match what we should have gotten next in the initial loop
assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
assert torch.equal(next_batch_data, batch_to_resume_from)


@pytest.mark.timeout(60)