Skip to content

Commit 526b3f3

Browse files
committed
Merge branch 'main' of https://github.com/liopeer/diffusionmodels into main
2 parents 9411bc0 + 0ecf72f commit 526b3f3

File tree

9 files changed

+319
-26
lines changed

9 files changed

+319
-26
lines changed

diffusion_models/models/multicoil.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
from torch import nn
3+
from jaxtyping import Float
4+
from typing import List
5+
from torch import Tensor
6+
7+
class MultiCoilConv2d(nn.Module):
    """2D convolution that is applied to every coil independently.

    The coil axis is folded into the batch axis, a regular ``nn.Conv2d`` is
    applied and the coil axis is restored afterwards. All coils therefore
    share the same convolution weights.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Construct the wrapped convolution.

        Parameters
        ----------
        *args, **kwargs
            forwarded unchanged to ``torch.nn.Conv2d``
        """
        super().__init__()
        self.conv2d = nn.Conv2d(*args, **kwargs)

    def forward(self, x: Float[Tensor, "batch coils in_channels height width"]) -> Float[Tensor, "batch coils out_channels height width"]:
        """Apply the convolution coil by coil.

        Parameters
        ----------
        x
            input whose last three axes are (channels, height, width); all
            leading axes (batch, coils) are treated as batch axes

        Returns
        -------
        Tensor
            tensor with the leading axes of ``x`` and the (channels, height,
            width) axes produced by the convolution
        """
        leading_shape = x.shape[:-3]
        # reshape instead of view: also accepts non-contiguous inputs
        out = self.conv2d(x.reshape(-1, *x.shape[-3:]))
        # The convolution may change the channel count (and, depending on
        # padding/stride, the spatial size), so the trailing axes must be taken
        # from the output and not from the input shape.
        return out.reshape(*leading_shape, *out.shape[-3:])
16+
17+
class MultiCoilReducer(nn.Module):
    def __init__(self, channel_factors: List[int]=(4, 8, 16, 32), kernel_size: int=3) -> None:
        """Constructor of MultiCoilReducer Class.

        This class takes every coil independently (treats them like a sub-fraction of a batch), increases the channel size
        massively (from 2 initial channels for complex k-space data) via several convolutional layers and then averages
        those channels over the coil dimension. Averaging is invariant to permutations of the input order, so the coil order
        or the number of coils will not matter anymore. Inspiration was drawn from point cloud processing, see below.

        .. [1] Qi et al., PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017

        Parameters
        ----------
        channel_factors
            sequence that includes all factors for channel increases; layer ``k`` outputs
            ``2 * channel_factors[k]`` channels
        kernel_size
            kernel size for conv layers
        """
        super().__init__()
        # NOTE(review): the channel progression 2 -> 2*f_0 -> 2*f_1 -> ... is an
        # interpretation of the docstring (factors relative to the 2 initial
        # channels) -- confirm that this is the intended layout.
        channels = [2] + [2 * factor for factor in channel_factors]
        # Registered as a submodule so that the parameters are seen by
        # optimizers, ``.to(device)`` and ``state_dict``.
        # TODO(review): no non-linearity is placed between the convolutions.
        self.layers = nn.Sequential(*[
            MultiCoilConv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel_size, padding="same")
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def forward(self, x: Float[Tensor, "batch coils 2 height width"]) -> Float[Tensor, "batch out_channels height width"]:
        """Expand the channels of every coil and average over the coil axis.

        Parameters
        ----------
        x
            multi-coil input with 2 channels per coil

        Returns
        -------
        Tensor
            coil-averaged features without the coil axis
        """
        x = self.layers(x)
        # mean over the coil axis: invariant to coil order and coil count
        return x.mean(dim=1)

diffusion_models/mri_forward/fft.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from torch.fft import fftn, ifftn, ifftshift, fftshift
2+
from typing import Union
3+
from jaxtyping import Float, Complex
4+
from torch import Tensor
5+
import torch
6+
from utils.helpers import complex_to_2channelfloat
7+
8+
def to_kspace(
    x: Union[
        Float[Tensor, "*batch 2 height width"],
        Complex[Tensor, "*batch height width"]
    ]
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]:
    """Transform data to k-space with a 2D FFT over the last two axes.

    The FFT is followed by an ``fftshift`` over the same two axes. The output
    uses the same representation as the input: complex in, complex out;
    2-channel float in, 2-channel float out.

    Parameters
    ----------
    x
        either a complex tensor or a real tensor whose third-to-last axis
        holds the (real, imaginary) parts

    Returns
    -------
    Tensor
        shifted spectrum in the representation of ``x``
    """
    if torch.is_complex(x):
        return fftshift(fftn(x, dim=(-2, -1)), dim=(-2, -1))
    # Move the (real, imag) axis to the end for view_as_complex. movedim works
    # for any number of leading batch axes; view_as_complex needs stride 1 on
    # the last axis, hence contiguous() (a no-op if the layout already fits).
    x = torch.view_as_complex(x.movedim(-3, -1).contiguous())
    x = fftshift(fftn(x, dim=(-2, -1)), dim=(-2, -1))
    return complex_to_2channelfloat(x)
22+
23+
def to_imgspace(
    x: Union[
        Float[Tensor, "*batch 2 height width"],
        Complex[Tensor, "*batch height width"]
    ]
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]:
    """Transform data to image space with a 2D inverse FFT over the last two axes.

    The inverse FFT is followed by an ``ifftshift`` over the same two axes. The
    output uses the same representation as the input: complex in, complex out;
    2-channel float in, 2-channel float out.

    Parameters
    ----------
    x
        either a complex tensor or a real tensor whose third-to-last axis
        holds the (real, imaginary) parts

    Returns
    -------
    Tensor
        shifted inverse transform in the representation of ``x``
    """
    # NOTE(review): the shift is applied after the inverse transform, so this
    # is not the exact inverse of ``fftshift(fftn(.))`` -- confirm that this
    # convention is intended before relying on round trips.
    if torch.is_complex(x):
        return ifftshift(ifftn(x, dim=(-2, -1)), dim=(-2, -1))
    # Move the (real, imag) axis to the end for view_as_complex. movedim works
    # for any number of leading batch axes; view_as_complex needs stride 1 on
    # the last axis, hence contiguous() (a no-op if the layout already fits).
    x = torch.view_as_complex(x.movedim(-3, -1).contiguous())
    x = ifftshift(ifftn(x, dim=(-2, -1)), dim=(-2, -1))
    return complex_to_2channelfloat(x)

diffusion_models/utils/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __getitem__(self, index) -> Any:
5353
index = self.imgs[index]["index"]
5454
file = h5py.File(file_name, 'r')
5555
x = file["reconstruction_rss"][index]
56-
x = self.transform(x)
56+
x = self.transform(np.array(x))
5757
file.close()
5858
x = x - x.min()
5959
x = x * (1 / x.max())

diffusion_models/utils/helpers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from jaxtyping import Float, Complex
2+
from torch import Tensor
3+
import torch
4+
15
class dotdict(dict):
26
"""dot.notation access to dictionary attributes"""
37
__getattr__ = dict.__getitem__
@@ -8,4 +12,8 @@ def bytes_to_gb(bytes: int):
812
kb = bytes / 1024
913
mb = kb / 1024
1014
gb = mb / 1024
11-
return gb
15+
return gb
16+
17+
def complex_to_2channelfloat(x: Complex[Tensor, "*batch height width"]) -> Float[Tensor, "*batch 2 height width"]:
    """Convert a complex tensor to a real tensor with a 2-element channel axis.

    Parameters
    ----------
    x
        complex tensor whose last two axes are (height, width)

    Returns
    -------
    Tensor
        real view of ``x`` with a new third-to-last axis holding the real part
        at index 0 and the imaginary part at index 1
    """
    # view_as_real appends the (real, imag) axis at the end; movedim places it
    # in front of (height, width) for any number of leading batch axes.
    return torch.view_as_real(x).movedim(-1, -3)

tests/fastmri_discovery.ipynb

Lines changed: 37 additions & 0 deletions
Large diffs are not rendered by default.

tests/job.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#SBATCH --account=student
33
#SBATCH --output=log/%j.out
44
#SBATCH --error=log/%j.err
5-
#SBATCH --gres=gpu:4
5+
#SBATCH --gres=gpu:2
66
#SBATCH --mem=64G
77
#SBATCH --job-name=mnist_double
88
#SBATCH --constraint='titan_xp|geforce_gtx_titan_x'

tests/k_space_discovery.ipynb

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import context\n",
10+
"from utils.datasets import QuarterFastMRI, MNISTTrainDataset, FastMRIBrainTrain\n",
11+
"from torch.utils.data import DataLoader\n",
12+
"import matplotlib.pyplot as plt\n",
13+
"import torch\n",
14+
"import os\n",
15+
"import h5py\n",
16+
"import torchvision\n",
17+
"from torchvision.transforms import Normalize\n",
18+
"from torchvision.io import read_image\n",
19+
"from torchvision.utils import make_grid\n",
20+
"import numpy as np\n",
21+
"from utils.helpers import complex_to_2channelfloat"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": 2,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"root = \"/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train\"\n",
31+
"\n",
32+
"h5_files = [os.path.join(root, elem) for elem in sorted(os.listdir(root))]\n",
33+
"imgs = []\n",
34+
"for file_name in h5_files:\n",
35+
" file = h5py.File(file_name, 'r')\n",
36+
" slices = file[\"reconstruction_rss\"].shape[0]\n",
37+
" for i in range(slices):\n",
38+
" imgs.append({\"file_name\":file_name, \"index\":i})"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": 22,
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"index = 100\n",
48+
"\n",
49+
"file_name = imgs[index][\"file_name\"]\n",
50+
"index = imgs[index][\"index\"]\n",
51+
"file = h5py.File(file_name, 'r')\n",
52+
"img = torch.tensor(np.array(file[\"kspace\"]))\n",
53+
"img = complex_to_2channelfloat(img)\n",
54+
"file.close()"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 32,
60+
"metadata": {},
61+
"outputs": [
62+
{
63+
"data": {
64+
"text/plain": [
65+
"torch.Size([10, 20, 40, 30])"
66+
]
67+
},
68+
"execution_count": 32,
69+
"metadata": {},
70+
"output_type": "execute_result"
71+
}
72+
],
73+
"source": [
74+
"x = torch.randn(10,20,30,40)\n",
75+
"x.permute(*[i for i in range(x.dim()-2)],-1,-2).shape"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": 7,
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"test = torch.zeros(16, 20, 2, 640, 320)\n",
85+
"for i in range(16):\n",
86+
" test[i] = i"
87+
]
88+
},
89+
{
90+
"cell_type": "code",
91+
"execution_count": 15,
92+
"metadata": {},
93+
"outputs": [],
94+
"source": [
95+
"orig_shape = test.shape\n",
96+
"test = test.view(-1, *orig_shape[-3:])\n",
97+
"test = test.view(*orig_shape)"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": 17,
103+
"metadata": {},
104+
"outputs": [
105+
{
106+
"name": "stdout",
107+
"output_type": "stream",
108+
"text": [
109+
"tensor([])\n",
110+
"tensor([])\n",
111+
"tensor([])\n",
112+
"tensor([])\n",
113+
"tensor([])\n",
114+
"tensor([])\n",
115+
"tensor([])\n",
116+
"tensor([])\n",
117+
"tensor([])\n",
118+
"tensor([])\n",
119+
"tensor([])\n",
120+
"tensor([])\n",
121+
"tensor([])\n",
122+
"tensor([])\n",
123+
"tensor([])\n",
124+
"tensor([])\n"
125+
]
126+
}
127+
],
128+
"source": [
129+
"for i in range(16):\n",
130+
" print(test[i][test[i]!=i])"
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": null,
136+
"metadata": {},
137+
"outputs": [],
138+
"source": []
139+
}
140+
],
141+
"metadata": {
142+
"kernelspec": {
143+
"display_name": "liotorch",
144+
"language": "python",
145+
"name": "python3"
146+
},
147+
"language_info": {
148+
"codemirror_mode": {
149+
"name": "ipython",
150+
"version": 3
151+
},
152+
"file_extension": ".py",
153+
"mimetype": "text/x-python",
154+
"name": "python",
155+
"nbconvert_exporter": "python",
156+
"pygments_lexer": "ipython3",
157+
"version": "3.11.5"
158+
}
159+
},
160+
"nbformat": 4,
161+
"nbformat_minor": 2
162+
}

tests/train_generative.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
cosine_ann_T_0 = 3,
3535
save_every = 1,
3636
num_samples = 9,
37-
batch_size = 48,
38-
gradient_accumulation_rate = 10,
37+
batch_size = 8,
38+
gradient_accumulation_rate = 64,
3939
learning_rate = 0.0001,
40-
img_size = 128,
40+
img_size = 256,
4141
device_type = "cuda",
4242
in_channels = 1,
4343
dataset = FastMRIBrainTrain,
@@ -62,7 +62,7 @@
6262
)
6363

6464
def load_train_objs(config):
65-
train_set = config.dataset(config.data_path)
65+
train_set = config.dataset(config.data_path, config.img_size)
6666
model = config.architecture(
6767
backbone = config.backbone(
6868
num_encoding_blocks = config.backbone_enc_depth,

0 commit comments

Comments
 (0)