-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
114 lines (95 loc) · 3.86 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from modeling import EyeGazeQAclip
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch.utils.data as data
import torch
import torch.nn.functional as F
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import os
# Seed the global RNGs (Python, NumPy, torch) and DataLoader workers for reproducibility.
# NOTE(review): this is called again just before Trainer construction below — one call suffices.
seed_everything(0, workers=True)
class EncodeDatatrain(Dataset):
    """Training split of the pre-extracted EyeGaze QA features.

    Loads a single ``.pt`` file (a sequence of per-sample dicts) produced by an
    upstream feature-extraction step and serves each sample as a flat list.
    """

    def __init__(self, data_path='final_data_questions2006train.pt'):
        """Load all samples into memory.

        Args:
            data_path: Path to the ``.pt`` file with the training samples.
                Defaults to the original hard-coded filename so existing
                callers (``EncodeDatatrain()``) are unaffected.
        """
        super().__init__()
        self.path = []    # kept for backward compatibility; appears unused
        self.length = 0   # kept for backward compatibility; appears unused
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted files (consider weights_only=True if the payload allows it).
        self.data = torch.load(data_path, map_location=torch.device('cpu'))

    def __getitem__(self, index):
        """Return one sample as a list:

        [question_number, Question, Video_Feature1..5,
         Text_Answer, Image_Answer, Video_Answer, Label]
        """
        instance = self.data[index]
        return [
            instance['question_number'],
            instance['Question'],
            instance['Video_Feature1'],
            instance['Video_Feature2'],
            instance['Video_Feature3'],
            instance['Video_Feature4'],
            instance['Video_Feature5'],
            instance['Text_Answer'],
            instance['Image_Answer'],
            instance['Video_Answer'],
            instance['Label'],
        ]

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)
class EncodeDatatest(Dataset):
    """Test split of the pre-extracted EyeGaze QA features.

    Identical sample layout to the training split; only the source file differs.
    """

    def __init__(self, data_path='final_data_questions2006test.pt'):
        """Load all samples into memory.

        Args:
            data_path: Path to the ``.pt`` file with the test samples.
                Defaults to the original hard-coded filename so existing
                callers (``EncodeDatatest()``) are unaffected.
        """
        super().__init__()
        self.path = []    # kept for backward compatibility; appears unused
        self.length = 0   # kept for backward compatibility; appears unused
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted files (consider weights_only=True if the payload allows it).
        self.data = torch.load(data_path, map_location=torch.device('cpu'))

    def __getitem__(self, index):
        """Return one sample as a list:

        [question_number, Question, Video_Feature1..5,
         Text_Answer, Image_Answer, Video_Answer, Label]
        """
        instance = self.data[index]
        return [
            instance['question_number'],
            instance['Question'],
            instance['Video_Feature1'],
            instance['Video_Feature2'],
            instance['Video_Feature3'],
            instance['Video_Feature4'],
            instance['Video_Feature5'],
            instance['Text_Answer'],
            instance['Image_Answer'],
            instance['Video_Answer'],
            instance['Label'],
        ]

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)
class EyeGazeQADataModule(pl.LightningDataModule):
    """Lightning data module bundling the train/test splits of the EyeGaze QA data.

    Both datasets are loaded eagerly at construction time, so instantiating
    this module reads the feature ``.pt`` files immediately.
    """

    def __init__(self, batch_size: int = 1024):
        """Store the batch size and load both splits into memory."""
        super().__init__()
        self.batch_size = batch_size
        self.train_set = EncodeDatatrain()
        self.test_set = EncodeDatatest()

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Deterministic (unshuffled) loader over the test split."""
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
# --- Training entry point -------------------------------------------------
# Re-seed immediately before building the Trainer so the run is reproducible
# even if earlier module-level code consumed RNG state.
seed_everything(0, workers=True)

logger = TensorBoardLogger("0621tensorboardlog", name="my_model")

trainer = Trainer(
    devices=[1],
    accelerator="gpu",
    strategy="ddp_find_unused_parameters_false",
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        # FIX: save_last=True is required so a 'last.ckpt' alias exists for
        # trainer.test(..., ckpt_path='last') below; with only save_top_k=1
        # no 'last' checkpoint is written and the test call cannot resolve it.
        ModelCheckpoint(save_top_k=1, save_last=True),
    ],
    benchmark=False,
    # NOTE(review): deterministic=False despite seeding — results may still
    # vary across runs due to nondeterministic CUDA kernels; set True to pin.
    deterministic=False,
    logger=logger,
    max_epochs=20,
    default_root_dir='training_results',
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
)

model = EyeGazeQAclip()
datamoduleA = EyeGazeQADataModule()
# ckpt_path=None: always start training from scratch (no resume).
trainer.fit(model, datamodule=datamoduleA, ckpt_path=None)
# Evaluate on the test split using the last saved checkpoint.
trainer.test(model, datamoduleA, ckpt_path='last')