Skip to content

Commit

Permalink
Rename linear problem solver and add parameter type checking (#457)
Browse files Browse the repository at this point in the history
* Rename ATADSolver to MatrixATADSolver

* Improve docs

* Add error checking on input types

* Another jaxlb/jax version bump

* Fix black formatting

* Fix short docstring

* Implement __array__ method in MatrixOperator

* Docstring edits

* Bump max jaxlib/jax versions

* Fix error message

---------

Co-authored-by: Michael-T-McCann <[email protected]>
  • Loading branch information
bwohlberg and Michael-T-McCann authored Oct 26, 2023
1 parent ea2aaf0 commit 3973e1b
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 38 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------

• No significant changes yet.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19.



Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ scipy>=1.6.0
tifffile
imageio>=2.17
matplotlib
jaxlib>=0.4.3,<=0.4.16
jax>=0.4.3,<=0.4.16
jaxlib>=0.4.3,<=0.4.19
jax>=0.4.3,<=0.4.19
flax>=0.6.1,<=0.6.9
svmbir>=0.3.3
pyabel>=0.9.0
13 changes: 6 additions & 7 deletions scico/linop/_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import numpy as np

import jax
import jax.numpy as jnp
from jax.dtypes import result_type
from jax.typing import ArrayLike

import scico.numpy as snp

Expand Down Expand Up @@ -65,7 +65,7 @@ def wrapper(a, b):
class MatrixOperator(LinearOperator):
"""Linear operator implementing matrix multiplication."""

def __init__(self, A: snp.Array, input_cols: int = 0):
def __init__(self, A: ArrayLike, input_cols: int = 0):
"""
Args:
A: Dense array. The action of the created
Expand All @@ -80,17 +80,16 @@ def __init__(self, A: snp.Array, input_cols: int = 0):
self.A: snp.Array #: Dense array implementing this matrix

# if A is an ndarray, make sure it gets converted to a jax array
if isinstance(A, jnp.ndarray):
self.A = A
elif isinstance(A, np.ndarray):
self.A = jax.device_put(A) # TODO: ensure_on_device?
else:
if not snp.util.is_arraylike(A):
raise TypeError(f"Expected numpy or jax array, got {type(A)}.")
self.A = jnp.array(A)

# Can only do rank-2 arrays
if A.ndim != 2:
raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}.")

self.__array__ = A.__array__ # enables jnp.array(H)

if input_cols == 0:
input_shape = A.shape[1]
output_shape = A.shape[0]
Expand Down
15 changes: 15 additions & 0 deletions scico/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,21 @@ def shape_to_size(shape: Union[Shape, BlockShape]) -> int:
return prod(shape)


def is_arraylike(x: Any) -> bool:
"""Check if input is of type :class:`jax.ArrayLike`.
`isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10,
see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices.
Args:
x: Object to be tested.
Returns:
``True`` if `x` is an ArrayLike, ``False`` otherwise.
"""
return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x)


def is_nested(x: Any) -> bool:
"""Check if input is a list/tuple containing at least one list/tuple.
Expand Down
10 changes: 5 additions & 5 deletions scico/optimize/_admmaux.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from scico.loss import SquaredL2Loss
from scico.numpy import Array, BlockArray
from scico.numpy.util import ensure_on_device, is_real_dtype
from scico.solver import ATADSolver, ConvATADSolver
from scico.solver import ConvATADSolver, MatrixATADSolver
from scico.solver import cg as scico_cg
from scico.solver import minimize

Expand Down Expand Up @@ -296,14 +296,14 @@ class MatrixSubproblemSolver(LinearSubproblemSolver):
\mb{u}^{(k)}_i) \;,
which is solved by factorization of the left hand side of the
equation, using :class:`.ATADSolver`.
equation, using :class:`.MatrixATADSolver`.
Attributes:
admm (:class:`.ADMM`): ADMM solver object to which the solver is
attached.
solve_kwargs (dict): Dictionary of arguments for solver
:class:`.ATADSolver` initialization.
:class:`.MatrixATADSolver` initialization.
"""

def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None):
Expand All @@ -313,7 +313,7 @@ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, A
check_solve: If ``True``, compute solver accuracy after each
solve.
solve_kwargs: Dictionary of arguments for solver
:class:`.ATADSolver` initialization.
:class:`.MatrixATADSolver` initialization.
"""
self.check_solve = check_solve
default_solve_kwargs = {"cho_factor": False}
Expand Down Expand Up @@ -352,7 +352,7 @@ def internal_init(self, admm: soa.ADMM):
Csum = reduce(
lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]
)
self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs)
self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs)

def solve(self, x0: Array) -> Array:
"""Solve the ADMM step.
Expand Down
58 changes: 40 additions & 18 deletions scico/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

import jax
import jax.experimental.host_callback as hcb
import jax.numpy as jnp
import jax.scipy.linalg as jsl

import scico.numpy as snp
Expand Down Expand Up @@ -260,14 +261,13 @@ def fun(x0):

def minimize_scalar(
func: Callable,
bracket: Optional[Union[Sequence[float]]] = None,
bracket: Optional[Sequence[float]] = None,
bounds: Optional[Sequence[float]] = None,
args: Union[Tuple, Tuple[Any]] = (),
method: str = "brent",
tol: Optional[float] = None,
options: Optional[dict] = None,
) -> spopt.OptimizeResult:

"""Minimization of scalar function of one variable.
Wrapper around :func:`scipy.optimize.minimize_scalar`.
Expand Down Expand Up @@ -579,8 +579,8 @@ def golden(
return r


class ATADSolver:
r"""Solver for linear system involving a symmetric product plus a diagonal.
class MatrixATADSolver:
r"""Solver for linear system involving a symmetric product.
Solve a linear system of the form
Expand All @@ -596,12 +596,18 @@ class ATADSolver:
where :math:`A \in \mbb{R}^{M \times N}`,
:math:`W \in \mbb{R}^{M \times M}` and
:math:`D \in \mbb{R}^{N \times N}`. The solution is computed by
factorization of matrix :math:`A^T W A + D` and solution via Gaussian
elimination. If :math:`D` is diagonal and :math:`N < M` (i.e.
:math:`A W A^T` is smaller than :math:`A^T W A`), then
:math:`A W A^T + D` is factorized and the original problem is solved
via the Woodbury matrix identity
:math:`D \in \mbb{R}^{N \times N}`. :math:`A` must be an instance of
:class:`.MatrixOperator` or an array; :math:`D` must be an instance
of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and
:math:`W`, if specified, must be an instance of :class:`.Diagonal`
or an array.
The solution is computed by factorization of matrix
:math:`A^T W A + D` and solution via Gaussian elimination. If
:math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is
smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized
and the original problem is solved via the Woodbury matrix identity
.. math::
Expand Down Expand Up @@ -698,8 +704,12 @@ def __init__(
r"""
Args:
A: Matrix :math:`A`.
D: Matrix :math:`D`.
W: Matrix :math:`W`.
D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`,
specifies the 2D matrix :math:`D`. If 1D array or
:class:`Diagonal`, specifies the diagonal elements
of :math:`D`.
W: Matrix :math:`W`. Specifies the diagonal elements of
:math:`W`. Defaults to an array with unit entries.
cho_factor: Flag indicating whether to use Cholesky
(``True``) or LU (``False``) factorization.
lower: Flag indicating whether lower (``True``) or upper
Expand All @@ -708,16 +718,28 @@ def __init__(
check_finite: Flag indicating whether the input array should
be checked for ``Inf`` and ``NaN`` values.
"""
if isinstance(A, MatrixOperator):
A = A.to_array()
if isinstance(D, MatrixOperator):
D = D.to_array()
elif isinstance(D, Diagonal):
A = jnp.array(A)

if isinstance(D, Diagonal):
D = D.diagonal
if not D.ndim == 1:
raise ValueError("If Diagonal, D should have a 1D diagonal.")
else:
D = jnp.array(D)
if not D.ndim in [1, 2]:
raise ValueError("If array or MatrixOperator, D should be 1D or 2D.")

if W is None:
W = snp.ones(A.shape[0], dtype=A.dtype)
elif isinstance(W, Diagonal):
W = W.diagonal
if not W.ndim == 1:
raise ValueError("If Diagonal, W should have a 1D diagonal.")
elif not isinstance(W, Array):
raise TypeError(
f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}."
)

self.A = A
self.D = D
self.W = W
Expand Down Expand Up @@ -796,7 +818,7 @@ def accuracy(self, x: Array, b: Array) -> float:


class ConvATADSolver:
r"""Solver for sum of convolutions plus diagonal linear system.
r"""Solver for a linear system involving a sum of convolutions.
Solve a linear system of the form
Expand Down
6 changes: 4 additions & 2 deletions scico/test/linop/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def setup_method(self, method):
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)])
def test_eval(self, matrix_shape, input_dtype, input_cols):

A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)
Ao = MatrixOperator(A, input_cols=input_cols)

Expand All @@ -38,7 +37,6 @@ def test_eval(self, matrix_shape, input_dtype, input_cols):
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)])
def test_adjoint(self, matrix_shape, input_dtype, input_cols):

A, key = randn(matrix_shape, dtype=input_dtype, key=self.key)
Ao = MatrixOperator(A, input_cols=input_cols)

Expand Down Expand Up @@ -262,6 +260,10 @@ def test_to_array(self):
assert isinstance(A_array, np.ndarray)
np.testing.assert_allclose(A_array, A)

A_array = jnp.array(Ao)
assert isinstance(A_array, jax.Array)
np.testing.assert_allclose(A_array, A)

@pytest.mark.parametrize("ord", ["fro", 2])
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("keepdims", [True, False])
Expand Down
6 changes: 3 additions & 3 deletions scico/test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_solve_atai(cho_factor, wide, weighted, alpha):
D = alpha * snp.ones((A.shape[1],))
ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1])
b = ATAD @ x0
slv = solver.ATADSolver(A, D, W=W, cho_factor=cho_factor)
slv = solver.MatrixATADSolver(A, D, W=W, cho_factor=cho_factor)
x1 = slv.solve(b)
assert metric.rel_res(x0, x1) < 5e-5

Expand All @@ -338,7 +338,7 @@ def test_solve_aati(cho_factor, wide, alpha):
D = alpha * snp.ones((A.shape[0],))
AATD = A @ A.T + alpha * snp.identity(A.shape[0])
b = AATD @ x0
slv = solver.ATADSolver(A.T, D)
slv = solver.MatrixATADSolver(A.T, D)
x1 = slv.solve(b)
assert metric.rel_res(x0, x1) < 5e-5

Expand All @@ -365,7 +365,7 @@ def test_solve_atad(cho_factor, wide, vector):
D = snp.abs(D) # only required for Cholesky, but improved accuracy for LU
ATAD = A.T @ A + snp.diag(D)
b = ATAD @ x0
slv = solver.ATADSolver(A, D, cho_factor=cho_factor)
slv = solver.MatrixATADSolver(A, D, cho_factor=cho_factor)
x1 = slv.solve(b)
assert metric.rel_res(x0, x1) < 5e-5
assert slv.accuracy(x1, b) < 5e-5

0 comments on commit 3973e1b

Please sign in to comment.