-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
319 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from torch import nn | ||
from jaxtyping import Float | ||
from typing import List | ||
from torch import Tensor | ||
|
||
class MultiCoilConv2d(nn.Module): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__() | ||
self.conv2d = nn.Conv2d(*args, **kwargs) | ||
|
||
def forward(self, x: Float[Tensor, "batch coils in_channels height width"]) -> Float[Tensor, "batch coils out_channels height width"]: | ||
orig_shape = x.shape | ||
x = self.conv2d(x.view(-1, *orig_shape[-3:])) | ||
return x.view(*orig_shape) | ||
|
||
class MultiCoilReducer(nn.Module): | ||
def __init__(self, channel_factors: List[int]=(4, 8, 16, 32), kernel_size: int=3) -> None: | ||
"""Constructor of MultiCoilReducer Class. | ||
This class takes every coil independently (treats them like a sub-fraction of a batch), increases the channel size | ||
massively (from 2 initial channels for complex k-space data) via several convolutional layers and then averages | ||
those channels over the coil dimension. Averaging is invariant to permutations of the input order, so the coil order | ||
or the number of coils will not matter anymore. Inspiration was drawn from point cloud processing, see below. | ||
.. [1] Qi et al., PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017 | ||
Parameters | ||
---------- | ||
channel_factors | ||
sequence that includes all factors for channel increases | ||
kernel_size | ||
kernel size for conv layers | ||
""" | ||
super().__init__() | ||
layers = [MultiCoilConv2d(in_channels=2*i, out_channels=2(i+1), kernel_size=kernel_size, padding="same") for i in channel_factors] | ||
|
||
def forward(self, x: Float[Tensor, "batch coils 2 height width"]) -> Float[Tensor, "batch out_channels height width"]: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from torch.fft import fftn, ifftn, ifftshift, fftshift | ||
from typing import Union | ||
from jaxtyping import Float, Complex | ||
from torch import Tensor | ||
import torch | ||
from utils.helpers import complex_to_2channelfloat | ||
|
||
def to_kspace( | ||
x: Union[ | ||
Float[Tensor, "*batch 2 height width"], | ||
Complex[Tensor, "*batch height width"] | ||
] | ||
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]: | ||
if torch.is_complex(x): | ||
x = fftn(x, dim=(-2,-1)) | ||
return fftshift(x, dim=(-2,-1)) | ||
else: | ||
x = torch.view_as_complex(x.permute(0,2,3,1)) | ||
x = fftn(x, dim=(-2,-1)) | ||
x = fftshift(x, dim=(-2,-1)) | ||
return complex_to_2channelfloat(x) | ||
|
||
def to_imgspace( | ||
x: Union[ | ||
Float[Tensor, "*batch 2 height width"], | ||
Complex[Tensor, "*batch height width"] | ||
] | ||
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]: | ||
if torch.is_complex(x): | ||
x = ifftn(x, dim=(-2,-1)) | ||
return ifftshift(x, dim=(-2,-1)) | ||
else: | ||
x = torch.view_as_complex(x.permute(0,2,3,1)) | ||
x = ifftn(x, dim=(-2,-1)) | ||
x = ifftshift(x, dim=(-2,-1)) | ||
return complex_to_2channelfloat(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import context\n", | ||
"from utils.datasets import QuarterFastMRI, MNISTTrainDataset, FastMRIBrainTrain\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import torch\n", | ||
"import os\n", | ||
"import h5py\n", | ||
"import torchvision\n", | ||
"from torchvision.transforms import Normalize\n", | ||
"from torchvision.io import read_image\n", | ||
"from torchvision.utils import make_grid\n", | ||
"import numpy as np\n", | ||
"from utils.helpers import complex_to_2channelfloat" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"root = \"/itet-stor/peerli/bmicdatasets-originals/Originals/fastMRI/brain/multicoil_train\"\n", | ||
"\n", | ||
"h5_files = [os.path.join(root, elem) for elem in sorted(os.listdir(root))]\n", | ||
"imgs = []\n", | ||
"for file_name in h5_files:\n", | ||
" file = h5py.File(file_name, 'r')\n", | ||
" slices = file[\"reconstruction_rss\"].shape[0]\n", | ||
" for i in range(slices):\n", | ||
" imgs.append({\"file_name\":file_name, \"index\":i})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"index = 100\n", | ||
"\n", | ||
"file_name = imgs[index][\"file_name\"]\n", | ||
"index = imgs[index][\"index\"]\n", | ||
"file = h5py.File(file_name, 'r')\n", | ||
"img = torch.tensor(np.array(file[\"kspace\"]))\n", | ||
"img = complex_to_2channelfloat(img)\n", | ||
"file.close()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 32, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([10, 20, 40, 30])" | ||
] | ||
}, | ||
"execution_count": 32, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"x = torch.randn(10,20,30,40)\n", | ||
"x.permute(*[i for i in range(x.dim()-2)],-1,-2).shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test = torch.zeros(16, 20, 2, 640, 320)\n", | ||
"for i in range(16):\n", | ||
" test[i] = i" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"orig_shape = test.shape\n", | ||
"test = test.view(-1, *orig_shape[-3:])\n", | ||
"test = test.view(*orig_shape)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n", | ||
"tensor([])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for i in range(16):\n", | ||
" print(test[i][test[i]!=i])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "liotorch", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.