Skip to content

Commit 9836419

Browse files
Add custom collate function for Getting Started example (resolves the collate_fn TypeError) (#607)
* fix: getting started default example by adding custom collate function for StreamingDataLoader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ab9613e commit 9836419

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

README.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,19 @@ Load the data by replacing the PyTorch DataSet and DataLoader with the Streaming
127127
```python
128128
import litdata as ld
129129

130-
train_dataset = ld.StreamingDataset('s3://my-bucket/fast_data', shuffle=True, drop_last=True)
131-
train_dataloader = ld.StreamingDataLoader(train_dataset)
130+
dataset = ld.StreamingDataset('s3://my-bucket/fast_data', shuffle=True, drop_last=True)
132131

133-
for sample in train_dataloader:
134-
img, cls = sample['image'], sample['class']
132+
# Custom collate function to handle the batch (Optional)
133+
def collate_fn(batch):
134+
return {
135+
"image": [sample["image"] for sample in batch],
136+
"class": [sample["class"] for sample in batch],
137+
}
138+
139+
140+
dataloader = ld.StreamingDataLoader(dataset, collate_fn=collate_fn)
141+
for sample in dataloader:
142+
img, cls = sample["image"], sample["class"]
135143
```
136144

137145
**Key benefits:**

examples/getting_started/README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,16 @@ img = sample['image']
6464
cls = sample['class']
6565

6666
# Create dataLoader and iterate over it to train your AI models.
67-
dataloader = StreamingDataLoader(dataset)
67+
68+
# Custom collate function to handle the batch (Optional)
69+
def collate_fn(batch):
70+
return {
71+
"image": [sample["image"] for sample in batch],
72+
"class": [sample["class"] for sample in batch],
73+
}
74+
75+
76+
dataloader = StreamingDataLoader(dataset, collate_fn=collate_fn)
77+
for sample in dataloader:
78+
img, cls = sample["image"], sample["class"]
6879
```

examples/getting_started/stream.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,16 @@
1212
cls = sample["class"]
1313

1414
# Create dataLoader and iterate over it to train your AI models.
15-
dataloader = StreamingDataLoader(dataset)
15+
16+
17+
# Custom collate function to handle the batch (Optional)
18+
def collate_fn(batch):
19+
return {
20+
"image": [sample["image"] for sample in batch],
21+
"class": [sample["class"] for sample in batch],
22+
}
23+
24+
25+
dataloader = StreamingDataLoader(dataset, collate_fn=collate_fn)
26+
for sample in dataloader:
27+
img, cls = sample["image"], sample["class"]

0 commit comments

Comments
 (0)