diff --git a/diffusion_models/utils/datasets.py b/diffusion_models/utils/datasets.py
index 687569d..e0395a4 100644
--- a/diffusion_models/utils/datasets.py
+++ b/diffusion_models/utils/datasets.py
@@ -43,6 +43,7 @@ def __init__(self, root: str, size: int=128) -> None:
             slices = file["reconstruction_rss"].shape[0]
             for i in range(slices):
                 self.imgs.append({"file_name":file_name, "index":i})
+            file.close()
         self.transform = Compose([ToTensor(), Resize((size, size), antialias=True)])
 
     def __len__(self):
@@ -61,7 +62,7 @@ def __getitem__(self, index) -> Any:
 
 class FastMRIDebug(FastMRIBrainTrain):
     def __len__(self):
-        return 512
+        return 128
 
 class QuarterFastMRI(FastMRIBrainTrain):
     """only every 4th image of original dataset"""
diff --git a/diffusion_models/utils/trainer.py b/diffusion_models/utils/trainer.py
index 8580c39..b2b5ea2 100644
--- a/diffusion_models/utils/trainer.py
+++ b/diffusion_models/utils/trainer.py
@@ -125,13 +125,13 @@ def _run_epoch(self, epoch: int):
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step(epoch + i / len(self.train_data))
                 if self.log_wandb:
-                    wandb.log({"learning_rate": self.lr_scheduler.get_last_lr()}, commit=False)
+                    wandb.log({"learning_rate": self.lr_scheduler.get_last_lr()[0]}, commit=False)
             epoch_losses.append(batch_loss)
             if self.log_wandb:
                 wandb.log({"epoch": epoch, "loss": batch_loss, "batch_time": time()-batch_time1})
         # only logging
         if self.log_wandb:
-            wandb.log({"epoch_loss": np.mean(epoch_losses), "epoch_time": time()-epoch_time1})
+            wandb.log({"epoch_loss": np.mean(epoch_losses), "epoch_time": time()-epoch_time1}, commit=False)
         self.loss_history.append(np.mean(epoch_losses))
         output = f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses):.5f} | Time: {time()-epoch_time1:.2f}s"
         if self.device_type == "cuda":
diff --git a/tests/train_generative.py b/tests/train_generative.py
index 9dbb76c..41f13e3 100644
--- a/tests/train_generative.py
+++ b/tests/train_generative.py
@@ -12,43 +12,47 @@
 import torch.multiprocessing as mp
 import os
 from utils.mp_setup import DDP_Proc_Group
-from utils.datasets import MNISTTrainDataset, Cifar10Dataset, MNISTDebugDataset, Cifar10DebugDataset, FastMRIDebug, FastMRIBrainTrain, QuarterFastMRI
+from utils.datasets import MNISTTrainDataset, MNISTDebugDataset, FastMRIDebug, FastMRIBrainTrain, QuarterFastMRI, Cifar10DebugDataset
 from utils.helpers import dotdict
 import wandb
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
 
 config = dotdict(
+    world_size = 2,
     total_epochs = 100,
     log_wandb = True,
     project = "fastMRI_gen_trials",
-    data_path = "/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train",
+    #data_path = "/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train",
     #data_path = "/itet-stor/peerli/net_scratch",
-    checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
-    wandb_dir = "/itet-stor/peerli/net_scratch",
+    data_path = "/home/lionel/Data/fastmri/multicoil_train",
+    #checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
+    checkpoint_folder = "/home/lionel/Data/checkpoints/run_name",
+    #wandb_dir = "/itet-stor/peerli/net_scratch",
+    wandb_dir = "/home/lionel/Data",
     from_checkpoint = False, #"/itet-stor/peerli/net_scratch/super-rain-7/checkpoint490.pt",
     loss_func = F.mse_loss,
     mixed_precision = True,
     optimizer = torch.optim.AdamW,
     lr_scheduler = "cosine_ann_warm",
     cosine_ann_T_0 = 3,
-    save_every = 1,
+    save_every = 3,
     num_samples = 9,
     batch_size = 8,
-    gradient_accumulation_rate = 64,
-    learning_rate = 0.0001,
+    gradient_accumulation_rate = 32,
+    learning_rate = 0.001,
     img_size = 256,
     device_type = "cuda",
     in_channels = 1,
     dataset = FastMRIBrainTrain,
     architecture = DiffusionModel,
     backbone = UNet,
-    attention = False,
+    attention = True,
     attention_heads = 4,
     attention_ff_dim = None,
     unet_init_channels = 64,
     activation = nn.SiLU,
-    backbone_enc_depth = 5,
+    backbone_enc_depth = 6,
     kernel_size = 3,
     dropout = 0.0,
     forward_diff = ForwardDiffusion,
@@ -62,7 +66,8 @@
 )
 
 def load_train_objs(config):
-    train_set = config.dataset(config.data_path, config.img_size)
+    #train_set = config.dataset(config.data_path, config.img_size)
+    train_set = config.dataset(config.data_path)
     model = config.architecture(
         backbone = config.backbone(
             num_encoding_blocks = config.backbone_enc_depth,
@@ -130,7 +135,10 @@ def training(rank, world_size, config):
 
 if __name__ == "__main__":
     if config.device_type == "cuda":
-        world_size = torch.cuda.device_count()
+        if "world_size" in config.keys():
+            world_size = config.world_size
+        else:
+            world_size = torch.cuda.device_count()
         print("Device Count:", world_size)
         mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size)
     else:
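
Note on the two wandb.log changes in diffusion_models/utils/trainer.py: the sketch below is illustrative only (not part of the patch; the Linear model and AdamW optimizer are placeholders) and shows why get_last_lr() is indexed and what commit=False does.

# Minimal sketch, assuming only torch is installed; the wandb calls are
# left as comments so the snippet runs without a wandb login.
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = nn.Linear(4, 1)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingWarmRestarts(opt, T_0=3)

# get_last_lr() returns a list with one entry per parameter group.
# Logging the bare list stores a non-scalar in wandb, so the patch
# indexes element 0 to log a plain float instead.
lr = scheduler.get_last_lr()[0]
assert isinstance(lr, float)

# commit=False buffers key/value pairs on the current wandb step; the
# next wandb.log call without commit=False flushes them all at once, so
# learning_rate and epoch_loss land on the same step as the per-batch
# metrics rather than each consuming a step of their own.
# wandb.log({"learning_rate": lr}, commit=False)  # buffered
# wandb.log({"loss": 0.5})                        # commits the step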