-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_module.py
88 lines (71 loc) · 2.43 KB
/
data_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

from config import get_config
class DataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule wrapping pre-built training and validation datasets.

    DataLoader settings (``batch_size``, ``num_workers``, ``pin_memory``) are read
    from the project configuration returned by ``get_config()``.
    """

    def __init__(
        self,
        train_ds: Dataset,
        val_ds: Dataset,
    ):
        """
        Initializes a PyTorch Lightning DataModule for handling training and validation data.

        Args:
            train_ds (Dataset): Training dataset.
            val_ds (Dataset): Validation dataset.
        """
        super().__init__()
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.config = get_config()

    def prepare_data(self):
        """
        Optional hook for data preparation (e.g. downloading or preprocessing).

        The datasets (`train_ds` and `val_ds`) are passed in already built, so
        there is nothing to do here.
        """
        pass

    def _make_loader(self, dataset: Dataset, batch_size: int, shuffle: bool) -> DataLoader:
        """Build a DataLoader for *dataset* using the worker/memory settings from config."""
        num_workers = self.config["num_workers"]
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=self.config["pin_memory"],
            # persistent_workers is only valid when worker processes are used.
            persistent_workers=num_workers != 0,
        )

    def setup(self, stage: Optional[str] = None):
        """
        Builds the DataLoaders required for the given stage.

        Args:
            stage (str, optional): Stage passed by the Trainer ('fit', 'validate',
                'test' or 'predict'). 'fit' builds both the training and the
                validation loader, 'validate' builds only the validation loader,
                and ``None`` (manual call) builds both. Other stages build nothing.
        """
        build_all = stage is None
        if build_all or stage == "fit":
            self.train = self._make_loader(
                self.train_ds, batch_size=self.config["batch_size"], shuffle=True
            )
        # trainer.fit() also runs the validation loop (and its sanity check), which
        # calls val_dataloader(), so the validation loader is needed for 'fit' too.
        if build_all or stage in ("fit", "validate"):
            # NOTE(review): batch_size=1 and shuffle=True are kept from the original.
            # Lightning warns about shuffled validation loaders; confirm shuffling is
            # intended (e.g. to show different random examples each epoch).
            self.val = self._make_loader(self.val_ds, batch_size=1, shuffle=True)

    def train_dataloader(self):
        """
        Returns the DataLoader for training.

        Requires `setup` to have been called with stage 'fit' (or None).

        Returns:
            DataLoader: Training DataLoader.
        """
        return self.train

    def val_dataloader(self):
        """
        Returns the DataLoader for validation.

        Requires `setup` to have been called with stage 'fit', 'validate' (or None).

        Returns:
            DataLoader: Validation DataLoader.
        """
        return self.val

    def predict_dataloader(self):
        """
        Placeholder for a prediction (inference) DataLoader.

        Prediction is not implemented in the current version; this returns None.
        """
        return None