Skip to content

Commit

Permalink
Switch ML examples to integrated X-ray transform (#562)
Browse files Browse the repository at this point in the history
* Bump maximum jaxlib/jax version and resolve some errors and warnings

* Handle get_backend change in jax 0.4.33

* Update checkpoint interface

* Remove unserializable partial functool from checkpoint

* Update orbax-checkpoint version constraint

* Switch from astra to scico 2D projector

* Renamed MoDL and ODP CT examples

* Minor edits

* Bug fix

* Choose parameters so that astra and scico projectors are equivalent

* Better choice of det_count parameter

* Fix test

* Improve installation instructions

* Improve warning in DnCNN function prox docs

* Typo fix

* Bump copyright year

* Resolve typing error

* Add filtered back projection for 2D projector

* Update change summary

* Docstring fixes

* Resolve errors in jitting method

* Switch to scico projector for CT training data generation

* Rename example

* Rename example

* Add conditional in case of prior ray.init

* Bug fix

* Trivial edit

* Bug fix

* Improve consistency with similar examples

* Bug fix

* Typo fix

* Remove astra import test

* Bug fix

* Some improvements

* Update example index

* Update secondary indices

* Update submodule

* Recent version of orbax-checkpoint not available via conda

* Remove mamba

* Minor change

* Remove mamba

* Fix string syntax

* Remove code that breaks notebook generation script

* Bump jaxlib/jax max version

* Fix script

* Change default matplotlib backend selection

* Update submodule

* Clean up

* Improve tests

* Improve mask mechanism

* Improve docs

* Minor docs edit

* Update submodule

* Typing fixes

* Avoid mismatch with declared linop dtype

* Address review comment

* Another search/replace error fix

* Update submodule

---------

Co-authored-by: crstngc <[email protected]>
Co-authored-by: Brendt Wohlberg <[email protected]>
  • Loading branch information
3 people authored Oct 23, 2024
1 parent 0dd2f57 commit 31c20a3
Show file tree
Hide file tree
Showing 21 changed files with 93 additions and 109 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Version 0.0.6 (unreleased)
• Rename ``scico.flax.save_weights`` and ``scico.flax.load_weights`` to
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.33.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.34.
• Support ``flax`` versions 0.8.0 to 0.9.0.


Expand Down
17 changes: 8 additions & 9 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ Computed Tomography
examples/ct_svmbir_ppp_bm3d_admm_cg
examples/ct_svmbir_ppp_bm3d_admm_prox
examples/ct_fan_svmbir_ppp_bm3d_admm_prox
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_modl_train_foam2
examples/ct_odp_train_foam2
examples/ct_unet_train_foam2
examples/ct_projector_comparison_2d
examples/ct_projector_comparison_3d
examples/ct_multi_cs_tv_admm
examples/ct_multi_tv_admm

Deconvolution
Expand Down Expand Up @@ -96,7 +95,7 @@ Miscellaneous
examples/denoise_dncnn_universal
examples/diffusercam_tv_admm
examples/video_rpca_admm
examples/ct_astra_datagen_foam2
examples/ct_datagen_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/denoise_datagen_bsds
Expand Down Expand Up @@ -181,10 +180,10 @@ Machine Learning
.. toctree::
:maxdepth: 1

examples/ct_astra_datagen_foam2
examples/ct_astra_modl_train_foam2
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_datagen_foam2
examples/ct_modl_train_foam2
examples/ct_odp_train_foam2
examples/ct_unet_train_foam2
examples/deconv_datagen_bsds
examples/deconv_datagen_foam1
examples/deconv_modl_train_foam1
Expand Down
4 changes: 2 additions & 2 deletions examples/jnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def py_file_to_string(src):

# Process remainder of source file
for line in srcfile:
if re.match("^input\(", line): # end processing when input statement encountered
if re.match(r"^input\(", line): # end processing when input statement encountered
break
line = re.sub('^r"""', '"""', line) # remove r from r"""
line = re.sub(":cite:\`([^`]+)\`", r'<cite data-cite="\1"/>', line) # fix cite format
line = re.sub(r":cite:\`([^`]+)\`", r'<cite data-cite="\1"/>', line) # fix cite format
lines.append(line)

# Backtrack through list of lines to remove trailing newlines
Expand Down
24 changes: 12 additions & 12 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ Computed Tomography
PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)
`ct_fan_svmbir_ppp_bm3d_admm_prox.py <ct_fan_svmbir_ppp_bm3d_admm_prox.py>`_
PPP (with BM3D) Fan-Beam CT Reconstruction
`ct_astra_modl_train_foam2.py <ct_astra_modl_train_foam2.py>`_
CT Training and Reconstructions with MoDL
`ct_astra_odp_train_foam2.py <ct_astra_odp_train_foam2.py>`_
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
`ct_modl_train_foam2.py <ct_modl_train_foam2.py>`_
CT Training and Reconstruction with MoDL
`ct_odp_train_foam2.py <ct_odp_train_foam2.py>`_
CT Training and Reconstruction with ODP
`ct_unet_train_foam2.py <ct_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`ct_projector_comparison_2d.py <ct_projector_comparison_2d.py>`_
2D X-ray Transform Comparison
Expand Down Expand Up @@ -123,7 +123,7 @@ Miscellaneous
TV-Regularized 3D DiffuserCam Reconstruction
`video_rpca_admm.py <video_rpca_admm.py>`_
Video Decomposition via Robust PCA
`ct_astra_datagen_foam2.py <ct_astra_datagen_foam2.py>`_
`ct_datagen_foam2.py <ct_datagen_foam2.py>`_
CT Data Generation for NN Training
`deconv_datagen_bsds.py <deconv_datagen_bsds.py>`_
Blurred Data Generation (Natural Images) for NN Training
Expand Down Expand Up @@ -239,13 +239,13 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^

`ct_astra_datagen_foam2.py <ct_astra_datagen_foam2.py>`_
`ct_datagen_foam2.py <ct_datagen_foam2.py>`_
CT Data Generation for NN Training
`ct_astra_modl_train_foam2.py <ct_astra_modl_train_foam2.py>`_
CT Training and Reconstructions with MoDL
`ct_astra_odp_train_foam2.py <ct_astra_odp_train_foam2.py>`_
CT Training and Reconstructions with ODP
`ct_astra_unet_train_foam2.py <ct_astra_unet_train_foam2.py>`_
`ct_modl_train_foam2.py <ct_modl_train_foam2.py>`_
CT Training and Reconstruction with MoDL
`ct_odp_train_foam2.py <ct_odp_train_foam2.py>`_
CT Training and Reconstruction with ODP
`ct_unet_train_foam2.py <ct_unet_train_foam2.py>`_
CT Training and Reconstructions with UNet
`deconv_datagen_bsds.py <deconv_datagen_bsds.py>`_
Blurred Data Generation (Natural Images) for NN Training
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
"""

# isort: off
import os
import numpy as np

import logging
import ray

ray.init(logging_level=logging.ERROR) # need to call init before jax import: ray-project/ray#44087

# Set an arbitrary processor count (only applies if GPU is not available).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from scico import plot
from scico.flax.examples import load_ct_data

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
# with the package.

r"""
CT Training and Reconstructions with MoDL
=========================================
CT Training and Reconstruction with MoDL
========================================
This example demonstrates the training and application of a
model-based deep learning (MoDL) architecture described in
Expand Down Expand Up @@ -65,7 +65,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D
from scico.linop.xray import XRayTransform2D

"""
Prepare parallel processing. Set an arbitrary processor count (only
Expand All @@ -89,16 +89,17 @@


"""
Build CT projection operator.
Build CT projection operator. Parameters are chosen so that the operator
is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
det_spacing=1,
det_count=N,
angles=angles,
) # CT projection operator
A = (1.0 / N) * A # normalized
det_count=int(N * 1.05 / np.sqrt(2.0)),
dx=1.0 / np.sqrt(2),
)
A = (1.0 / N) * A # normalize projection operator


"""
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_multi_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))

det_count = N
det_count = int(N * 1.05 / np.sqrt(2.0))
det_spacing = np.sqrt(2)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
# with the package.

r"""
CT Training and Reconstructions with ODP
========================================
CT Training and Reconstruction with ODP
=======================================
This example demonstrates the training of the unrolled optimization with
deep priors (ODP) gradient descent architecture described in
Expand Down Expand Up @@ -72,7 +72,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D
from scico.linop.xray import XRayTransform2D


platform = get_backend().platform
Expand All @@ -92,21 +92,22 @@


"""
Build CT projection operator.
Build CT projection operator. Parameters are chosen so that the operator
is equivalent to the one used to generate the training data.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform2D(
input_shape=(N, N),
det_spacing=1,
det_count=N,
angles=angles,
) # CT projection operator
A = (1.0 / N) * A # normalized
det_count=int(N * 1.05 / np.sqrt(2.0)),
dx=1.0 / np.sqrt(2),
)
A = (1.0 / N) * A # normalize projection operator


"""
Build training and testing structures. Inputs are the sinograms and
outpus are the original generated foams. Keep training and testing
outputs are the original generated foams. Keep training and testing
partitions.
"""
numtr = 320
Expand Down
8 changes: 3 additions & 5 deletions examples/scripts/ct_projector_comparison_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
Create a ground truth image.
"""
N = 512

det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))

x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jnp.array(x_gt)

Expand All @@ -41,17 +38,18 @@
"""
num_angles = 500
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
det_count = int(N * 1.02 / jnp.sqrt(2.0))

timer = Timer()

projectors = {}
timer.start("scico_init")
projectors["scico"] = XRayTransform2D((N, N), angles)
projectors["scico"] = XRayTransform2D((N, N), angles, det_count=det_count)
timer.stop("scico_init")

timer.start("astra_init")
projectors["astra"] = astra.XRayTransform2D(
(N, N), det_count=det_count, det_spacing=1.0, angles=angles - jnp.pi / 2.0
(N, N), det_count=det_count, det_spacing=np.sqrt(2), angles=angles - jnp.pi / 2.0
)
timer.stop("astra_init")

Expand Down
File renamed without changes.
16 changes: 8 additions & 8 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ Computed Tomography
- ct_svmbir_ppp_bm3d_admm_cg.py
- ct_svmbir_ppp_bm3d_admm_prox.py
- ct_fan_svmbir_ppp_bm3d_admm_prox.py
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_modl_train_foam2.py
- ct_odp_train_foam2.py
- ct_unet_train_foam2.py
- ct_projector_comparison_2d.py
- ct_projector_comparison_3d.py
- ct_multi_tv_admm.py
Expand Down Expand Up @@ -73,7 +73,7 @@ Miscellaneous
- denoise_dncnn_universal.py
- diffusercam_tv_admm.py
- video_rpca_admm.py
- ct_astra_datagen_foam2.py
- ct_datagen_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- denoise_datagen_bsds.py
Expand Down Expand Up @@ -143,10 +143,10 @@ Sparsity
Machine Learning
^^^^^^^^^^^^^^^^

- ct_astra_datagen_foam2.py
- ct_astra_modl_train_foam2.py
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_datagen_foam2.py
- ct_modl_train_foam2.py
- ct_odp_train_foam2.py
- ct_unet_train_foam2.py
- deconv_datagen_bsds.py
- deconv_datagen_foam1.py
- deconv_modl_train_foam1.py
Expand Down
1 change: 0 additions & 1 deletion misc/conda/install_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ rm -f /tmp/miniconda.sh
export PATH="$CONDAHOME/bin:$PATH"
hash -r
conda config --set always_yes yes
conda install mamba -n base -c conda-forge
conda update -q conda
conda info -a

Expand Down
11 changes: 4 additions & 7 deletions misc/conda/make_conda_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ EOF
)
# Requirements that cannot be installed via conda (i.e. have to use pip)
NOCONDA=$(cat <<-EOF
flax bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
flax orbax-checkpoint bm3d bm4d py2jn colour_demosaicing hyperopt ray[tune,train]
EOF
)

Expand Down Expand Up @@ -217,19 +217,16 @@ eval "$(conda shell.bash hook)" # required to avoid errors re: `conda init`
conda activate $ENVNM # Q: why not `source activate`? A: not always in the path

# Add conda-forge channel
conda config --env --append channels conda-forge

# Install mamba
conda install mamba -n base -c conda-forge
conda config --append channels conda-forge

# Install required conda packages (and extra useful packages)
mamba install $CONDA_FLAGS $CONDAREQ ipython
conda install $CONDA_FLAGS $CONDAREQ ipython

# Utility ffmpeg is required by imageio for reading mp4 video files
# it can also be installed via the system package manager, .e.g.
# sudo apt install ffmpeg
if [ "$(which ffmpeg)" = '' ]; then
mamba install $CONDA_FLAGS ffmpeg
conda install $CONDA_FLAGS ffmpeg
fi

# Install jaxlib and jax
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.33
jax>=0.4.3,<=0.4.33
jaxlib>=0.4.3,<=0.4.34
jax>=0.4.3,<=0.4.34
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.9.0
pyabel>=0.9.0
Loading

0 comments on commit 31c20a3

Please sign in to comment.