Skip to content

Commit 76398c3

Browse files
author
Anton Andreychuk
authored
Merge pull request #4 from danissomo/main
RAM consumption fix and the quality of life changes
2 parents b5a94aa + 6b51cef commit 76398c3

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,8 @@ venv.bak/
127127
dmypy.json
128128

129129
# Pyre type checker
130-
.pyre/
130+
.pyre/
131+
132+
133+
TransPath_data
134+
wandb

data/hmaps.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,26 +108,27 @@ class GridData(Dataset):
108108
cf - correction factor values
109109
"""
110110
def __init__(self, path, mode='f', clip_value=0.95):
111-
maps = np.load(os.path.join(path, 'maps.npy')).astype('float32')
112-
self.maps = torch.tensor(maps)
113-
goals = np.load(os.path.join(path, 'goals.npy')).astype('float32')
114-
self.goals = torch.tensor(goals)
115-
starts = np.load(os.path.join(path, 'starts.npy')).astype('float32')
116-
self.starts = torch.tensor(starts)
111+
self.clip_v = clip_value
112+
self.mode = mode
113+
114+
self.maps = np.load(os.path.join(path, 'maps.npy'), mmap_mode='c')
115+
self.goals = np.load(os.path.join(path, 'goals.npy'), mmap_mode='c')
116+
self.starts = np.load(os.path.join(path, 'starts.npy'), mmap_mode='c')
117117

118-
if mode == 'f':
119-
gt_values = np.load(os.path.join(path, 'focal.npy')).astype('float32')
120-
gt_values = torch.tensor(gt_values)
121-
self.gt_values = torch.where(gt_values >= clip_value, gt_values, torch.zeros_like(gt_values))
122-
elif mode == 'h':
123-
gt_values = np.load(os.path.join(path, 'abs.npy')).astype('float32')
124-
self.gt_values = torch.tensor(gt_values)
125-
elif mode == 'cf':
126-
gt_values = np.load(os.path.join(path, 'cf.npy')).astype('float32')
127-
self.gt_values = torch.tensor(gt_values)
128-
118+
file_gt = {'f' : 'focal.npy', 'h':'abs.npy', 'cf': 'cf.npy'}[mode]
119+
self.gt_values = np.load(os.path.join(path, file_gt), mmap_mode='c')
120+
121+
129122
def __len__(self):
130123
return len(self.gt_values)
131124

125+
126+
132127
def __getitem__(self, idx):
133-
return self.maps[idx], self.starts[idx], self.goals[idx], self.gt_values[idx]
128+
gt_ = torch.from_numpy(self.gt_values[idx].astype('float32'))
129+
if self.mode == 'f':
130+
gt_= torch.where( gt_ >= self.clip_v, gt_ , torch.zeros_like( torch.from_numpy(self.gt_values[idx])))
131+
return (torch.from_numpy(self.maps[idx].astype('float32')),
132+
torch.from_numpy(self.starts[idx].astype('float32')),
133+
torch.from_numpy(self.goals[idx].astype('float32')),
134+
gt_ )

train.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import wandb
66
from torch.utils.data import DataLoader
77
from pytorch_lightning.loggers import WandbLogger
8+
import torch
89

910
import argparse
11+
import multiprocessing
1012

1113

12-
def main(mode, run_name, proj_name):
14+
def main(mode, run_name, proj_name, batch_size, max_epochs):
1315
train_data = GridData(
1416
path='./TransPath_data/train',
1517
mode=mode
@@ -18,18 +20,25 @@ def main(mode, run_name, proj_name):
1820
path='./TransPath_data/val',
1921
mode=mode
2022
)
21-
train_dataloader = DataLoader(train_data, batch_size=256,
22-
shuffle=True, num_workers=0, pin_memory=True)
23-
val_dataloader = DataLoader(val_data, batch_size=256,
24-
shuffle=False, num_workers=0, pin_memory=True)
23+
train_dataloader = DataLoader( train_data,
24+
batch_size=batch_size,
25+
shuffle=True,
26+
num_workers=multiprocessing.cpu_count(),
27+
pin_memory=True)
28+
val_dataloader = DataLoader( val_data,
29+
batch_size=batch_size,
30+
shuffle=False,
31+
num_workers=multiprocessing.cpu_count(),
32+
pin_memory=True)
33+
2534
samples = next(iter(val_dataloader))
2635

2736
model = Autoencoder(mode=mode)
28-
wandb_logger = WandbLogger(project=proj_name, name=run_name)
37+
wandb_logger = WandbLogger(project=proj_name, name=f'{run_name}_{mode}')
2938
trainer = pl.Trainer(
3039
logger=wandb_logger,
31-
gpus=-1,
32-
max_epochs=50,
40+
accelerator="auto",
41+
max_epochs=max_epochs,
3342
deterministic=False,
3443
callbacks=[PathLogger(samples, mode=mode)],
3544
)
@@ -42,12 +51,16 @@ def main(mode, run_name, proj_name):
4251
parser.add_argument('--run_name', type=str, default='default')
4352
parser.add_argument('--proj_name', type=str, default='TransPath_runs')
4453
parser.add_argument('--seed', type=int, default=39)
54+
parser.add_argument('--batch', type=int, default=256)
55+
parser.add_argument('--epoch', type=int, default=15)
4556

4657
args = parser.parse_args()
4758
pl.seed_everything(args.seed)
48-
59+
torch.set_float32_matmul_precision('high') #fix for tesor blocks warning with new video card
4960
main(
5061
mode=args.mode,
5162
run_name=args.run_name,
52-
proj_name=args.proj_name
63+
proj_name=args.proj_name,
64+
batch_size=args.batch,
65+
max_epochs=args.epoch
5366
)

0 commit comments

Comments
 (0)