Skip to content

Commit

Permalink
added fastmri dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Oct 31, 2023
1 parent 11dfb17 commit 9f93d91
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 18 deletions.
22 changes: 22 additions & 0 deletions diffusion_models/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pickle
import numpy as np
import torch
import h5py

class Cifar10Dataset(CIFAR10):
def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
Expand All @@ -32,6 +33,27 @@ def __init__(self, root: str, train: bool = True, transform: Callable[..., Any]
class MNISTDebugDataset(MNISTTrainDataset):
__len__ = lambda x: 100

class FastMRIBrainTrain(Dataset):
def __init__(self, root: str, size: int) -> None:
super().__init__()
h5_files = [os.path.join(root, elem) for elem in sorted(os.listdir(root))]
self.imgs = []
for file_name in h5_files:
file = h5py.File(file_name, 'r')
slices = file["reconstruction_rss"].shape[0]
for i in range(slices):
self.imgs.append({"file_name":file_name, "index":i})
self.transform = Compose([Normalize((4.8358e-05, ), (np.sqrt(2.4383e-09), )), Resize((size, size), antialias=True)])

def __len__(self):
return len(self.imgs)

def __getitem__(self, index) -> Any:
file_name = self.imgs[index]["file_name"]
index = self.imgs[index]["index"]
file = h5py.File(file_name, 'r')
return self.transform(torch.tensor(file["reconstruction_rss"][index]).unsqueeze(0))

class ImageNet64Dataset(Dataset):
def __init__(self, root: str) -> None:
super().__init__()
Expand Down
66 changes: 66 additions & 0 deletions tests/fastmri_discovery.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import context\n",
"from utils.datasets import FastMRIBrainTrain, MNISTTrainDataset\n",
"from torch.utils.data import DataLoader\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from torchvision.transforms import Normalize"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(-0.0132) tensor(0.9088)\n"
]
}
],
"source": [
"ds = FastMRIBrainTrain(\"/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train\", size=256)\n",
"dl = DataLoader(ds, batch_size=50)\n",
"x = next(iter(dl))\n",
"print(torch.mean(x.view(-1)), torch.var(x.view(-1)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "liotorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
12 changes: 6 additions & 6 deletions tests/flexible_foward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@
activation = nn.SiLU,
time_emb_size = config.time_enc_dim,
init_channels = 128,
attention = False,
attention_heads = 0,
attention_ff_dim = 0
attention = True,
attention_heads = 4,
attention_ff_dim = None
),
fwd_diff = config.forward_diff(
timesteps = config.max_timesteps,
start = config.t_start,
end = config.t_end,
offset = config.offset,
max_beta = config.max_beta,
type = "linear"
type = "cosine"
),
img_size = config.img_size,
time_enc_dim = config.time_enc_dim,
dropout = config.dropout
)

model = model.to("cuda")
model.load_state_dict(torch.load("/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint90.pt"))
model.load_state_dict(torch.load("/itet-stor/peerli/net_scratch/fearful-werewolf-11/checkpoint90.pt"))

samples = model.sample(9)
samples = torchvision.utils.make_grid(samples, nrow=int(sqrt(9)))
torchvision.utils.save_image(samples, "/home/peerli/Downloads/sample2.png")
torchvision.utils.save_image(samples, "/home/peerli/Downloads/sample9.png")
2 changes: 1 addition & 1 deletion tests/job.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#SBATCH --account=student
#SBATCH --output=log/%j.out
#SBATCH --error=log/%j.err
#SBATCH --gres=gpu:2
#SBATCH --gres=gpu:4
#SBATCH --mem=32G
#SBATCH --job-name=mnist_double
#SBATCH --constraint='titan_xp|geforce_gtx_titan_x'
Expand Down
22 changes: 11 additions & 11 deletions tests/train_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,27 @@
import torch.nn.functional as F

config = dotdict(
total_epochs = 2,
log_wandb = False,
project = "mnist_gen_trials",
total_epochs = 3000,
log_wandb = True,
project = "cifar_gen_trials",
data_path = "/itet-stor/peerli/net_scratch",
checkpoint_folder = "/itet-stor/peerli/net_scratch/run_name", # append wandb run name to this path
wandb_dir = "/itet-stor/peerli/",
from_checkpoint = "/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint30.pt",
wandb_dir = "/itet-stor/peerli/net_scratch",
from_checkpoint = False, #"/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint30.pt",
loss_func = F.mse_loss,
save_every = 1,
save_every = 30,
num_samples = 9,
show_denoising_history = False,
show_history_every = 50,
batch_size = 256,
batch_size = 64,
learning_rate = 0.0003,
img_size = 32,
device_type = "cuda",
in_channels = 1,
dataset = MNISTDebugDataset,
in_channels = 3,
dataset = Cifar10Dataset,
architecture = DiffusionModel,
backbone = UNet,
attention = False,
attention = True,
attention_heads = 4,
attention_ff_dim = None,
unet_init_channels = 128,
Expand All @@ -52,7 +52,7 @@
t_end = 0.02,
offset = 0.008,
max_beta = 0.999,
schedule_type = "linear",
schedule_type = "cosine",
time_enc_dim = 128,
optimizer = torch.optim.Adam
)
Expand Down

0 comments on commit 9f93d91

Please sign in to comment.