forked from saiprasath21/TransRecG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
35 lines (26 loc) · 805 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from src.data.process_data import get_data
from src.data.movie_dataset import MovieDataset
from src.config import *
# Script body: prepare the data, build one dataset per split, and wrap each
# split in a torch DataLoader. Everything runs at import time, as before.

# NOTE(review): get_data() comes from src.data.process_data; it presumably
# produces the CSV files read below — confirm against that module.
get_data()

# Create one dataset per split.
train_dataset = MovieDataset("src/data/train.csv")
val_dataset = MovieDataset("src/data/validation.csv")
test_dataset = MovieDataset("src/data/test.csv")

# Create dataloaders. Only the training loader shuffles: reshuffling each epoch
# benefits training. Validation/test loaders iterate in dataset index order
# (shuffle=False -> SequentialSampler). That keeps every evaluation pass
# reproducible and lets outputs be matched back to dataset indices.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=VALIDATION_BATCH_SIZE,
    shuffle=False,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
)