updates to trainer
liopeer committed Nov 14, 2023
1 parent 3d727f5 commit 7579d4d
Showing 3 changed files with 23 additions and 14 deletions.
3 changes: 2 additions & 1 deletion diffusion_models/utils/datasets.py
@@ -43,6 +43,7 @@ def __init__(self, root: str, size: int=128) -> None:
             slices = file["reconstruction_rss"].shape[0]
             for i in range(slices):
                 self.imgs.append({"file_name":file_name, "index":i})
+            file.close()
         self.transform = Compose([ToTensor(), Resize((size, size), antialias=True)])

     def __len__(self):
@@ -61,7 +62,7 @@ def __getitem__(self, index) -> Any:

 class FastMRIDebug(FastMRIBrainTrain):
     def __len__(self):
-        return 512
+        return 128

 class QuarterFastMRI(FastMRIBrainTrain):
     """only every 4th image of original dataset"""
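Editor's note: the datasets.py change plugs a file-handle leak. The constructor opens every HDF5 volume under root to read its slice count, and previously never closed any of them; the added file.close() releases each handle once the shape is read. Below is a minimal sketch of the same indexing loop written with a context manager instead, assuming the repository reads fastMRI volumes with h5py (the file["reconstruction_rss"] access suggests it); index_slices and the flat directory layout are illustrative, not code from this repo:

import os
import h5py

def index_slices(root: str) -> list:
    """Build a flat (file, slice) index without leaving HDF5 handles open."""
    imgs = []
    for file_name in sorted(os.listdir(root)):
        # the with-block closes the file even if the shape read raises
        with h5py.File(os.path.join(root, file_name), "r") as file:
            slices = file["reconstruction_rss"].shape[0]
        for i in range(slices):
            imgs.append({"file_name": file_name, "index": i})
    return imgs

The explicit close() in the commit has the same effect; the context manager just makes it exception-safe.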
4 changes: 2 additions & 2 deletions diffusion_models/utils/trainer.py
@@ -125,13 +125,13 @@ def _run_epoch(self, epoch: int):
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step(epoch + i / len(self.train_data))
                 if self.log_wandb:
-                    wandb.log({"learning_rate": self.lr_scheduler.get_last_lr()}, commit=False)
+                    wandb.log({"learning_rate": self.lr_scheduler.get_last_lr()[0]}, commit=False)
             epoch_losses.append(batch_loss)
             if self.log_wandb:
                 wandb.log({"epoch": epoch, "loss": batch_loss, "batch_time": time()-batch_time1})
         # only logging
         if self.log_wandb:
-            wandb.log({"epoch_loss": np.mean(epoch_losses), "epoch_time": time()-epoch_time1})
+            wandb.log({"epoch_loss": np.mean(epoch_losses), "epoch_time": time()-epoch_time1}, commit=False)
         self.loss_history.append(np.mean(epoch_losses))
         output = f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {self.batch_size} | Steps: {len(self.train_data)} | Loss: {np.mean(epoch_losses):.5f} | Time: {time()-epoch_time1:.2f}s"
         if self.device_type == "cuda":
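Editor's note: both trainer.py edits concern Weights & Biases logging. get_last_lr() on a PyTorch scheduler returns a list with one learning rate per parameter group, so indexing [0] logs a plain float rather than a one-element list. Passing commit=False stages a metric in the current wandb step instead of opening a new one; staged values are flushed together with the next log call that does commit. A small self-contained sketch of both patterns (the offline mode and the toy model exist only to make it runnable; this is not code from this repo):

import torch
import wandb

wandb.init(project="demo", mode="offline")  # offline: no account or network needed

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# T_0=3 mirrors cosine_ann_T_0 in the training config below
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3)

for step in range(5):
    optimizer.step()          # a no-op without gradients; keeps the scheduler-order warning quiet
    scheduler.step(step / 5)  # fractional epochs, as _run_epoch does with epoch + i / len(train_data)
    # get_last_lr() returns one lr per parameter group; [0] yields a scalar for charting
    wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, commit=False)
    wandb.log({"loss": 1.0 / (step + 1)})  # commits both metrics as one step

wandb.finish()

Note that the second change also marks the end-of-epoch epoch_loss log with commit=False, so it is only written out together with the next committing log call.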
30 changes: 19 additions & 11 deletions tests/train_generative.py
@@ -12,43 +12,47 @@
 import torch.multiprocessing as mp
 import os
 from utils.mp_setup import DDP_Proc_Group
-from utils.datasets import MNISTTrainDataset, Cifar10Dataset, MNISTDebugDataset, Cifar10DebugDataset, FastMRIDebug, FastMRIBrainTrain, QuarterFastMRI
+from utils.datasets import MNISTTrainDataset, MNISTDebugDataset, FastMRIDebug, FastMRIBrainTrain, QuarterFastMRI, Cifar10DebugDataset
 from utils.helpers import dotdict
 import wandb
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts

 config = dotdict(
+    world_size = 2,
     total_epochs = 100,
     log_wandb = True,
     project = "fastMRI_gen_trials",
-    data_path = "/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train",
+    #data_path = "/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train",
     #data_path = "/itet-stor/peerli/net_scratch",
-    checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
-    wandb_dir = "/itet-stor/peerli/net_scratch",
+    data_path = "/home/lionel/Data/fastmri/multicoil_train",
+    #checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
+    checkpoint_folder = "/home/lionel/Data/checkpoints/run_name",
+    #wandb_dir = "/itet-stor/peerli/net_scratch",
+    wandb_dir = "/home/lionel/Data",
     from_checkpoint = False, #"/itet-stor/peerli/net_scratch/super-rain-7/checkpoint490.pt",
     loss_func = F.mse_loss,
     mixed_precision = True,
     optimizer = torch.optim.AdamW,
     lr_scheduler = "cosine_ann_warm",
     cosine_ann_T_0 = 3,
-    save_every = 1,
+    save_every = 3,
     num_samples = 9,
     batch_size = 8,
-    gradient_accumulation_rate = 64,
-    learning_rate = 0.0001,
+    gradient_accumulation_rate = 32,
+    learning_rate = 0.001,
     img_size = 256,
     device_type = "cuda",
     in_channels = 1,
     dataset = FastMRIBrainTrain,
     architecture = DiffusionModel,
     backbone = UNet,
-    attention = False,
+    attention = True,
     attention_heads = 4,
     attention_ff_dim = None,
     unet_init_channels = 64,
     activation = nn.SiLU,
-    backbone_enc_depth = 5,
+    backbone_enc_depth = 6,
     kernel_size = 3,
     dropout = 0.0,
     forward_diff = ForwardDiffusion,
@@ -62,7 +66,8 @@
 )

 def load_train_objs(config):
-    train_set = config.dataset(config.data_path, config.img_size)
+    #train_set = config.dataset(config.data_path, config.img_size)
+    train_set = config.dataset(config.data_path)
     model = config.architecture(
         backbone = config.backbone(
             num_encoding_blocks = config.backbone_enc_depth,
@@ -130,7 +135,10 @@ def training(rank, world_size, config):

 if __name__ == "__main__":
     if config.device_type == "cuda":
-        world_size = torch.cuda.device_count()
+        if "world_size" in config.keys():
+            world_size = config.world_size
+        else:
+            world_size = torch.cuda.device_count()
         print("Device Count:", world_size)
         mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size)
     else:
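Editor's note: the last hunk lets the config pin the number of DDP workers instead of always spawning one process per visible GPU: with world_size = 2 set in the config, mp.spawn launches exactly two ranks even on a node with more devices. If the trainer applies the usual multipliers, one optimizer step now covers batch_size × gradient_accumulation_rate × world_size = 8 × 32 × 2 = 512 images; that is an inference from the config values, not something the diff states. A minimal sketch of the resolution logic, assuming config behaves like a plain dict (the .keys() call in the diff implies it); resolve_world_size is a hypothetical helper, the commit inlines the same check in the __main__ block:

import torch

def resolve_world_size(config: dict) -> int:
    """Prefer an explicit world_size from the config; fall back to GPU count."""
    if "world_size" in config.keys():
        return config["world_size"]
    return torch.cuda.device_count()

# usage, mirroring the __main__ block above:
#   world_size = resolve_world_size(config)
#   mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size)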
