diff --git a/diffusion_models/utils/datasets.py b/diffusion_models/utils/datasets.py
index bd00277..454790a 100644
--- a/diffusion_models/utils/datasets.py
+++ b/diffusion_models/utils/datasets.py
@@ -7,6 +7,7 @@
 import pickle
 import numpy as np
 import torch
+import h5py
 
 class Cifar10Dataset(CIFAR10):
     def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
@@ -32,6 +33,27 @@ def __init__(self, root: str, train: bool = True, transform: Callable[..., Any]
 class MNISTDebugDataset(MNISTTrainDataset):
     __len__ = lambda x: 100
 
+class FastMRIBrainTrain(Dataset):
+    def __init__(self, root: str, size: int) -> None:
+        super().__init__()
+        h5_files = [os.path.join(root, elem) for elem in sorted(os.listdir(root))]
+        self.imgs = []
+        for file_name in h5_files:
+            file = h5py.File(file_name, 'r')
+            slices = file["reconstruction_rss"].shape[0]
+            for i in range(slices):
+                self.imgs.append({"file_name":file_name, "index":i})
+        self.transform = Compose([Normalize((4.8358e-05, ), (np.sqrt(2.4383e-09), )), Resize((size, size), antialias=True)])
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def __getitem__(self, index) -> Any:
+        file_name = self.imgs[index]["file_name"]
+        index = self.imgs[index]["index"]
+        file = h5py.File(file_name, 'r')
+        return self.transform(torch.tensor(file["reconstruction_rss"][index]).unsqueeze(0))
+
 class ImageNet64Dataset(Dataset):
     def __init__(self, root: str) -> None:
         super().__init__()
diff --git a/tests/fastmri_discovery.ipynb b/tests/fastmri_discovery.ipynb
new file mode 100644
index 0000000..17e9af6
--- /dev/null
+++ b/tests/fastmri_discovery.ipynb
@@ -0,0 +1,66 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import context\n",
+    "from utils.datasets import FastMRIBrainTrain, MNISTTrainDataset\n",
+    "from torch.utils.data import DataLoader\n",
+    "import matplotlib.pyplot as plt\n",
+    "import torch\n",
+    "from torchvision.transforms import Normalize"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor(-0.0132) tensor(0.9088)\n"
+     ]
+    }
+   ],
+   "source": [
+    "ds = FastMRIBrainTrain(\"/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train\", size=256)\n",
+    "dl = DataLoader(ds, batch_size=50)\n",
+    "x = next(iter(dl))\n",
+    "print(torch.mean(x.view(-1)), torch.var(x.view(-1)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "liotorch",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/flexible_foward.py b/tests/flexible_foward.py
index 8c8a584..49c8232 100644
--- a/tests/flexible_foward.py
+++ b/tests/flexible_foward.py
@@ -13,9 +13,9 @@
         activation = nn.SiLU,
         time_emb_size = config.time_enc_dim,
         init_channels = 128,
-        attention = False,
-        attention_heads = 0,
-        attention_ff_dim = 0
+        attention = True,
+        attention_heads = 4,
+        attention_ff_dim = None
     ),
     fwd_diff = config.forward_diff(
         timesteps = config.max_timesteps,
@@ -23,7 +23,7 @@
         end = config.t_end,
         offset = config.offset,
         max_beta = config.max_beta,
-        type = "linear"
+        type = "cosine"
     ),
     img_size = config.img_size,
     time_enc_dim = config.time_enc_dim,
@@ -31,8 +31,8 @@
 )
 
 model = model.to("cuda")
-model.load_state_dict(torch.load("/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint90.pt"))
+model.load_state_dict(torch.load("/itet-stor/peerli/net_scratch/fearful-werewolf-11/checkpoint90.pt"))
 
 samples = model.sample(9)
 samples = torchvision.utils.make_grid(samples, nrow=int(sqrt(9)))
-torchvision.utils.save_image(samples, "/home/peerli/Downloads/sample2.png")
\ No newline at end of file
+torchvision.utils.save_image(samples, "/home/peerli/Downloads/sample9.png")
\ No newline at end of file
diff --git a/tests/job.sh b/tests/job.sh
index 4cdefe8..85e8704 100644
--- a/tests/job.sh
+++ b/tests/job.sh
@@ -2,7 +2,7 @@
 #SBATCH --account=student
 #SBATCH --output=log/%j.out
 #SBATCH --error=log/%j.err
-#SBATCH --gres=gpu:2
+#SBATCH --gres=gpu:4
 #SBATCH --mem=32G
 #SBATCH --job-name=mnist_double
 #SBATCH --constraint='titan_xp|geforce_gtx_titan_x'
diff --git a/tests/train_generative.py b/tests/train_generative.py
index 0b36764..5aa67c9 100644
--- a/tests/train_generative.py
+++ b/tests/train_generative.py
@@ -18,27 +18,27 @@
 import torch.nn.functional as F
 
 config = dotdict(
-    total_epochs = 2,
-    log_wandb = False,
-    project = "mnist_gen_trials",
+    total_epochs = 3000,
+    log_wandb = True,
+    project = "cifar_gen_trials",
     data_path = "/itet-stor/peerli/net_scratch",
     checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
-    wandb_dir = "/itet-stor/peerli/",
-    from_checkpoint = "/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint30.pt",
+    wandb_dir = "/itet-stor/peerli/net_scratch",
+    from_checkpoint = False, #"/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint30.pt",
     loss_func = F.mse_loss,
-    save_every = 1,
+    save_every = 30,
     num_samples = 9,
     show_denoising_history = False,
     show_history_every = 50,
-    batch_size = 256,
+    batch_size = 64,
     learning_rate = 0.0003,
     img_size = 32,
     device_type = "cuda",
-    in_channels = 1,
-    dataset = MNISTDebugDataset,
+    in_channels = 3,
+    dataset = Cifar10Dataset,
     architecture = DiffusionModel,
     backbone = UNet,
-    attention = False,
+    attention = True,
     attention_heads = 4,
     attention_ff_dim = None,
     unet_init_channels = 128,
@@ -52,7 +52,7 @@
     t_end = 0.02,
     offset = 0.008,
     max_beta = 0.999,
-    schedule_type = "linear",
+    schedule_type = "cosine",
     time_enc_dim = 128,
     optimizer = torch.optim.Adam
 )
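For orientation, a minimal sketch of how the new `FastMRIBrainTrain` dataset could be exercised outside the notebook. The import path, the `FASTMRI_ROOT` placeholder, and the final printout are assumptions made for illustration; they mirror the mean/variance check in `tests/fastmri_discovery.ipynb` rather than anything added by this diff.

```python
# Sketch only: assumes the package is importable as diffusion_models and that
# FASTMRI_ROOT points at a directory of fastMRI brain h5 files containing a
# "reconstruction_rss" dataset, as FastMRIBrainTrain expects.
import torch
from torch.utils.data import DataLoader

from diffusion_models.utils.datasets import FastMRIBrainTrain

FASTMRI_ROOT = "/path/to/fastMRI/brain/multicoil_train"  # hypothetical location

ds = FastMRIBrainTrain(FASTMRI_ROOT, size=256)  # one item per slice, resized to 256x256
dl = DataLoader(ds, batch_size=50)

x = next(iter(dl))  # expected shape: (50, 1, 256, 256)
# With the Normalize((4.8358e-05,), (sqrt(2.4383e-09),)) transform baked into the
# dataset, the batch should come out roughly zero-mean / unit-variance, in line
# with the tensor(-0.0132) / tensor(0.9088) values printed in the notebook.
print(x.shape, torch.mean(x.view(-1)).item(), torch.var(x.view(-1)).item())
```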