55import wandb
66from torch .utils .data import DataLoader
77from pytorch_lightning .loggers import WandbLogger
8+ import torch
89
910import argparse
11+ import multiprocessing
1012
1113
12- def main (mode , run_name , proj_name ):
14+ def main (mode , run_name , proj_name , batch_size , max_epochs ):
1315 train_data = GridData (
1416 path = './TransPath_data/train' ,
1517 mode = mode
@@ -18,18 +20,25 @@ def main(mode, run_name, proj_name):
1820 path = './TransPath_data/val' ,
1921 mode = mode
2022 )
21- train_dataloader = DataLoader (train_data , batch_size = 256 ,
22- shuffle = True , num_workers = 0 , pin_memory = True )
23- val_dataloader = DataLoader (val_data , batch_size = 256 ,
24- shuffle = False , num_workers = 0 , pin_memory = True )
23+ train_dataloader = DataLoader ( train_data ,
24+ batch_size = batch_size ,
25+ shuffle = True ,
26+ num_workers = multiprocessing .cpu_count (),
27+ pin_memory = True )
28+ val_dataloader = DataLoader ( val_data ,
29+ batch_size = batch_size ,
30+ shuffle = False ,
31+ num_workers = multiprocessing .cpu_count (),
32+ pin_memory = True )
33+
2534 samples = next (iter (val_dataloader ))
2635
2736 model = Autoencoder (mode = mode )
28- wandb_logger = WandbLogger (project = proj_name , name = run_name )
37+ wandb_logger = WandbLogger (project = proj_name , name = f' { run_name } _ { mode } ' )
2938 trainer = pl .Trainer (
3039 logger = wandb_logger ,
31- gpus = - 1 ,
32- max_epochs = 50 ,
40+ accelerator = "auto" ,
41+ max_epochs = max_epochs ,
3342 deterministic = False ,
3443 callbacks = [PathLogger (samples , mode = mode )],
3544 )
@@ -42,12 +51,16 @@ def main(mode, run_name, proj_name):
4251 parser .add_argument ('--run_name' , type = str , default = 'default' )
4352 parser .add_argument ('--proj_name' , type = str , default = 'TransPath_runs' )
4453 parser .add_argument ('--seed' , type = int , default = 39 )
54+ parser .add_argument ('--batch' , type = int , default = 256 )
55+ parser .add_argument ('--epoch' , type = int , default = 15 )
4556
4657 args = parser .parse_args ()
4758 pl .seed_everything (args .seed )
48-
59+ torch . set_float32_matmul_precision ( 'high' ) #fix for tesor blocks warning with new video card
4960 main (
5061 mode = args .mode ,
5162 run_name = args .run_name ,
52- proj_name = args .proj_name
63+ proj_name = args .proj_name ,
64+ batch_size = args .batch ,
65+ max_epochs = args .epoch
5366 )
0 commit comments