Skip to content

Commit

Permalink
multicoil trials
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Nov 10, 2023
1 parent e4b29da commit 0a73be4
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 26 deletions.
39 changes: 39 additions & 0 deletions diffusion_models/models/multicoil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from torch import nn
from jaxtyping import Float
from typing import List
from torch import Tensor

class MultiCoilConv2d(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__()
self.conv2d = nn.Conv2d(*args, **kwargs)

def forward(self, x: Float[Tensor, "batch coils in_channels height width"]) -> Float[Tensor, "batch coils out_channels height width"]:
orig_shape = x.shape
x = self.conv2d(x.view(-1, *orig_shape[-3:]))
return x.view(*orig_shape)

class MultiCoilReducer(nn.Module):
def __init__(self, channel_factors: List[int]=(4, 8, 16, 32), kernel_size: int=3) -> None:
"""Constructor of MultiCoilReducer Class.
This class takes every coil independently (treats them like a sub-fraction of a batch), increases the channel size
massively (from 2 initial channels for complex k-space data) via several convolutional layers and then averages
those channels over the coil dimension. Averaging is invariant to permutations of the input order, so the coil order
or the number of coils will not matter anymore. Inspiration was drawn from point cloud processing, see below.
.. [1] Qi et al., PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017
Parameters
----------
channel_factors
sequence that includes all factors for channel increases
kernel_size
kernel size for conv layers
"""
super().__init__()
layers = [MultiCoilConv2d(in_channels=2*i, out_channels=2(i+1), kernel_size=kernel_size, padding="same") for i in channel_factors]

def forward(self, x: Float[Tensor, "batch coils 2 height width"]) -> Float[Tensor, "batch out_channels height width"]:
pass
36 changes: 36 additions & 0 deletions diffusion_models/mri_forward/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from torch.fft import fftn, ifftn, ifftshift, fftshift
from typing import Union
from jaxtyping import Float, Complex
from torch import Tensor
import torch
from utils.helpers import complex_to_2channelfloat

def to_kspace(
x: Union[
Float[Tensor, "*batch 2 height width"],
Complex[Tensor, "*batch height width"]
]
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]:
if torch.is_complex(x):
x = fftn(x, dim=(-2,-1))
return fftshift(x, dim=(-2,-1))
else:
x = torch.view_as_complex(x.permute(0,2,3,1))
x = fftn(x, dim=(-2,-1))
x = fftshift(x, dim=(-2,-1))
return complex_to_2channelfloat(x)

def to_imgspace(
x: Union[
Float[Tensor, "*batch 2 height width"],
Complex[Tensor, "*batch height width"]
]
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]:
if torch.is_complex(x):
x = ifftn(x, dim=(-2,-1))
return ifftshift(x, dim=(-2,-1))
else:
x = torch.view_as_complex(x.permute(0,2,3,1))
x = ifftn(x, dim=(-2,-1))
x = ifftshift(x, dim=(-2,-1))
return complex_to_2channelfloat(x)
2 changes: 1 addition & 1 deletion diffusion_models/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __getitem__(self, index) -> Any:
index = self.imgs[index]["index"]
file = h5py.File(file_name, 'r')
x = file["reconstruction_rss"][index]
x = self.transform(x)
x = self.transform(np.array(x))
file.close()
x = x - x.min()
x = x * (1 / x.max())
Expand Down
10 changes: 9 additions & 1 deletion diffusion_models/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from jaxtyping import Float, Complex
from torch import Tensor
import torch

class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.__getitem__
Expand All @@ -8,4 +12,8 @@ def bytes_to_gb(bytes: int):
kb = bytes / 1024
mb = kb / 1024
gb = mb / 1024
return gb
return gb

def complex_to_2channelfloat(x: Complex[Tensor, "*batch height width"]) -> Float[Tensor, "*batch 2 height width"]:
x = torch.view_as_real(x)
return x.permute(*[i for i in range(x.dim()-2)],-1,-3,-2)
37 changes: 37 additions & 0 deletions tests/fastmri_discovery.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/job.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#SBATCH --account=student
#SBATCH --output=log/%j.out
#SBATCH --error=log/%j.err
#SBATCH --gres=gpu:4
#SBATCH --gres=gpu:2
#SBATCH --mem=64G
#SBATCH --job-name=mnist_double
#SBATCH --constraint='titan_xp|geforce_gtx_titan_x'
Expand Down
162 changes: 162 additions & 0 deletions tests/k_space_discovery.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import context\n",
"from utils.datasets import QuarterFastMRI, MNISTTrainDataset, FastMRIBrainTrain\n",
"from torch.utils.data import DataLoader\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import os\n",
"import h5py\n",
"import torchvision\n",
"from torchvision.transforms import Normalize\n",
"from torchvision.io import read_image\n",
"from torchvision.utils import make_grid\n",
"import numpy as np\n",
"from utils.helpers import complex_to_2channelfloat"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"root = \"/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train\"\n",
"\n",
"h5_files = [os.path.join(root, elem) for elem in sorted(os.listdir(root))]\n",
"imgs = []\n",
"for file_name in h5_files:\n",
" file = h5py.File(file_name, 'r')\n",
" slices = file[\"reconstruction_rss\"].shape[0]\n",
" for i in range(slices):\n",
" imgs.append({\"file_name\":file_name, \"index\":i})"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"index = 100\n",
"\n",
"file_name = imgs[index][\"file_name\"]\n",
"index = imgs[index][\"index\"]\n",
"file = h5py.File(file_name, 'r')\n",
"img = torch.tensor(np.array(file[\"kspace\"]))\n",
"img = complex_to_2channelfloat(img)\n",
"file.close()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([10, 20, 40, 30])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.randn(10,20,30,40)\n",
"x.permute(*[i for i in range(x.dim()-2)],-1,-2).shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"test = torch.zeros(16, 20, 2, 640, 320)\n",
"for i in range(16):\n",
" test[i] = i"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"orig_shape = test.shape\n",
"test = test.view(-1, *orig_shape[-3:])\n",
"test = test.view(*orig_shape)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n",
"tensor([])\n"
]
}
],
"source": [
"for i in range(16):\n",
" print(test[i][test[i]!=i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "liotorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
8 changes: 4 additions & 4 deletions tests/train_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
cosine_ann_T_0 = 3,
save_every = 1,
num_samples = 9,
batch_size = 48,
gradient_accumulation_rate = 10,
batch_size = 8,
gradient_accumulation_rate = 64,
learning_rate = 0.0001,
img_size = 128,
img_size = 256,
device_type = "cuda",
in_channels = 1,
dataset = FastMRIBrainTrain,
Expand All @@ -62,7 +62,7 @@
)

def load_train_objs(config):
train_set = config.dataset(config.data_path)
train_set = config.dataset(config.data_path, config.img_size)
model = config.architecture(
backbone = config.backbone(
num_encoding_blocks = config.backbone_enc_depth,
Expand Down
Loading

0 comments on commit 0a73be4

Please sign in to comment.