From 3973e1b41fd0910c528e2d17969f1fe65d38b6b6 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 26 Oct 2023 14:16:22 -0600 Subject: [PATCH] Rename linear problem solver and add parameter type checking (#457) * Rename ATADSolver to MatrixATADSolver * Improve docs * Add error checking on input types * Another jaxlb/jax version bump * Fix black formatting * Fix short docstring * Implement __array__ method in MatrixOperator * Docstring edits * Bump max jaxlib/jax versions * Fix error message --------- Co-authored-by: Michael-T-McCann --- CHANGES.rst | 3 +- requirements.txt | 4 +-- scico/linop/_matrix.py | 13 ++++---- scico/numpy/util.py | 15 +++++++++ scico/optimize/_admmaux.py | 10 +++--- scico/solver.py | 58 +++++++++++++++++++++++---------- scico/test/linop/test_matrix.py | 6 ++-- scico/test/test_solver.py | 6 ++-- 8 files changed, 77 insertions(+), 38 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index ebc467893..a413f53cb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,8 @@ SCICO Release Notes Version 0.0.5 (unreleased) ---------------------------- -• No significant changes yet. +• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``. +• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19. diff --git a/requirements.txt b/requirements.txt index ca4dc28b5..68ab1ebac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,8 @@ scipy>=1.6.0 tifffile imageio>=2.17 matplotlib -jaxlib>=0.4.3,<=0.4.16 -jax>=0.4.3,<=0.4.16 +jaxlib>=0.4.3,<=0.4.19 +jax>=0.4.3,<=0.4.19 flax>=0.6.1,<=0.6.9 svmbir>=0.3.3 pyabel>=0.9.0 diff --git a/scico/linop/_matrix.py b/scico/linop/_matrix.py index 0f6f21daa..951c6957e 100644 --- a/scico/linop/_matrix.py +++ b/scico/linop/_matrix.py @@ -17,9 +17,9 @@ import numpy as np -import jax import jax.numpy as jnp from jax.dtypes import result_type +from jax.typing import ArrayLike import scico.numpy as snp @@ -65,7 +65,7 @@ def wrapper(a, b): class MatrixOperator(LinearOperator): """Linear operator implementing matrix multiplication.""" - def __init__(self, A: snp.Array, input_cols: int = 0): + def __init__(self, A: ArrayLike, input_cols: int = 0): """ Args: A: Dense array. The action of the created @@ -80,17 +80,16 @@ def __init__(self, A: snp.Array, input_cols: int = 0): self.A: snp.Array #: Dense array implementing this matrix # if A is an ndarray, make sure it gets converted to a jax array - if isinstance(A, jnp.ndarray): - self.A = A - elif isinstance(A, np.ndarray): - self.A = jax.device_put(A) # TODO: ensure_on_device? - else: + if not snp.util.is_arraylike(A): raise TypeError(f"Expected numpy or jax array, got {type(A)}.") + self.A = jnp.array(A) # Can only do rank-2 arrays if A.ndim != 2: raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}.") + self.__array__ = A.__array__ # enables jnp.array(H) + if input_cols == 0: input_shape = A.shape[1] output_shape = A.shape[0] diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 90620dc92..50fefdd4e 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -204,6 +204,21 @@ def shape_to_size(shape: Union[Shape, BlockShape]) -> int: return prod(shape) +def is_arraylike(x: Any) -> bool: + """Check if input is of type :class:`jax.ArrayLike`. + + `isinstance(x, jax.typing.ArrayLike)` does not work in Python < 3.10, + see https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices. + + Args: + x: Object to be tested. + + Returns: + ``True`` if `x` is an ArrayLike, ``False`` otherwise. + """ + return isinstance(x, (np.ndarray, jax.Array)) or np.isscalar(x) + + def is_nested(x: Any) -> bool: """Check if input is a list/tuple containing at least one list/tuple. diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py index 7a0ceb0b2..7a1c8710c 100644 --- a/scico/optimize/_admmaux.py +++ b/scico/optimize/_admmaux.py @@ -31,7 +31,7 @@ from scico.loss import SquaredL2Loss from scico.numpy import Array, BlockArray from scico.numpy.util import ensure_on_device, is_real_dtype -from scico.solver import ATADSolver, ConvATADSolver +from scico.solver import ConvATADSolver, MatrixATADSolver from scico.solver import cg as scico_cg from scico.solver import minimize @@ -296,14 +296,14 @@ class MatrixSubproblemSolver(LinearSubproblemSolver): \mb{u}^{(k)}_i) \;, which is solved by factorization of the left hand side of the - equation, using :class:`.ATADSolver`. + equation, using :class:`.MatrixATADSolver`. Attributes: admm (:class:`.ADMM`): ADMM solver object to which the solver is attached. solve_kwargs (dict): Dictionary of arguments for solver - :class:`.ATADSolver` initialization. + :class:`.MatrixATADSolver` initialization. """ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, Any]] = None): @@ -313,7 +313,7 @@ def __init__(self, check_solve: bool = False, solve_kwargs: Optional[dict[str, A check_solve: If ``True``, compute solver accuracy after each solve. solve_kwargs: Dictionary of arguments for solver - :class:`.ATADSolver` initialization. + :class:`.MatrixATADSolver` initialization. """ self.check_solve = check_solve default_solve_kwargs = {"cho_factor": False} @@ -352,7 +352,7 @@ def internal_init(self, admm: soa.ADMM): Csum = reduce( lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)] ) - self.solver = ATADSolver(A, Csum, W, **self.solve_kwargs) + self.solver = MatrixATADSolver(A, Csum, W, **self.solve_kwargs) def solve(self, x0: Array) -> Array: """Solve the ADMM step. diff --git a/scico/solver.py b/scico/solver.py index 5f0994246..f93cd710e 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -54,6 +54,7 @@ import jax import jax.experimental.host_callback as hcb +import jax.numpy as jnp import jax.scipy.linalg as jsl import scico.numpy as snp @@ -260,14 +261,13 @@ def fun(x0): def minimize_scalar( func: Callable, - bracket: Optional[Union[Sequence[float]]] = None, + bracket: Optional[Sequence[float]] = None, bounds: Optional[Sequence[float]] = None, args: Union[Tuple, Tuple[Any]] = (), method: str = "brent", tol: Optional[float] = None, options: Optional[dict] = None, ) -> spopt.OptimizeResult: - """Minimization of scalar function of one variable. Wrapper around :func:`scipy.optimize.minimize_scalar`. @@ -579,8 +579,8 @@ def golden( return r -class ATADSolver: - r"""Solver for linear system involving a symmetric product plus a diagonal. +class MatrixATADSolver: + r"""Solver for linear system involving a symmetric product. Solve a linear system of the form @@ -596,12 +596,18 @@ class ATADSolver: where :math:`A \in \mbb{R}^{M \times N}`, :math:`W \in \mbb{R}^{M \times M}` and - :math:`D \in \mbb{R}^{N \times N}`. The solution is computed by - factorization of matrix :math:`A^T W A + D` and solution via Gaussian - elimination. If :math:`D` is diagonal and :math:`N < M` (i.e. - :math:`A W A^T` is smaller than :math:`A^T W A`), then - :math:`A W A^T + D` is factorized and the original problem is solved - via the Woodbury matrix identity + :math:`D \in \mbb{R}^{N \times N}`. :math:`A` must be an instance of + :class:`.MatrixOperator` or an array; :math:`D` must be an instance + of :class:`.MatrixOperator`, :class:`.Diagonal`, or an array, and + :math:`W`, if specified, must be an instance of :class:`.Diagonal` + or an array. + + + The solution is computed by factorization of matrix + :math:`A^T W A + D` and solution via Gaussian elimination. If + :math:`D` is diagonal and :math:`N < M` (i.e. :math:`A W A^T` is + smaller than :math:`A^T W A`), then :math:`A W A^T + D` is factorized + and the original problem is solved via the Woodbury matrix identity .. math:: @@ -698,8 +704,12 @@ def __init__( r""" Args: A: Matrix :math:`A`. - D: Matrix :math:`D`. - W: Matrix :math:`W`. + D: Matrix :math:`D`. If a 2D array or :class:`MatrixOperator`, + specifies the 2D matrix :math:`D`. If 1D array or + :class:`Diagonal`, specifies the diagonal elements + of :math:`D`. + W: Matrix :math:`W`. Specifies the diagonal elements of + :math:`W`. Defaults to an array with unit entries. cho_factor: Flag indicating whether to use Cholesky (``True``) or LU (``False``) factorization. lower: Flag indicating whether lower (``True``) or upper @@ -708,16 +718,28 @@ def __init__( check_finite: Flag indicating whether the input array should be checked for ``Inf`` and ``NaN`` values. """ - if isinstance(A, MatrixOperator): - A = A.to_array() - if isinstance(D, MatrixOperator): - D = D.to_array() - elif isinstance(D, Diagonal): + A = jnp.array(A) + + if isinstance(D, Diagonal): D = D.diagonal + if not D.ndim == 1: + raise ValueError("If Diagonal, D should have a 1D diagonal.") + else: + D = jnp.array(D) + if not D.ndim in [1, 2]: + raise ValueError("If array or MatrixOperator, D should be 1D or 2D.") + if W is None: W = snp.ones(A.shape[0], dtype=A.dtype) elif isinstance(W, Diagonal): W = W.diagonal + if not W.ndim == 1: + raise ValueError("If Diagonal, W should have a 1D diagonal.") + elif not isinstance(W, Array): + raise TypeError( + f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}." + ) + self.A = A self.D = D self.W = W @@ -796,7 +818,7 @@ def accuracy(self, x: Array, b: Array) -> float: class ConvATADSolver: - r"""Solver for sum of convolutions plus diagonal linear system. + r"""Solver for a linear system involving a sum of convolutions. Solve a linear system of the form diff --git a/scico/test/linop/test_matrix.py b/scico/test/linop/test_matrix.py index 3f00b2c7a..178c1fce5 100644 --- a/scico/test/linop/test_matrix.py +++ b/scico/test/linop/test_matrix.py @@ -22,7 +22,6 @@ def setup_method(self, method): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_eval(self, matrix_shape, input_dtype, input_cols): - A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) @@ -38,7 +37,6 @@ def test_eval(self, matrix_shape, input_dtype, input_cols): @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("matrix_shape", [(3, 3), (3, 4)]) def test_adjoint(self, matrix_shape, input_dtype, input_cols): - A, key = randn(matrix_shape, dtype=input_dtype, key=self.key) Ao = MatrixOperator(A, input_cols=input_cols) @@ -262,6 +260,10 @@ def test_to_array(self): assert isinstance(A_array, np.ndarray) np.testing.assert_allclose(A_array, A) + A_array = jnp.array(Ao) + assert isinstance(A_array, jax.Array) + np.testing.assert_allclose(A_array, A) + @pytest.mark.parametrize("ord", ["fro", 2]) @pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("keepdims", [True, False]) diff --git a/scico/test/test_solver.py b/scico/test/test_solver.py index d5b179b62..f220482df 100644 --- a/scico/test/test_solver.py +++ b/scico/test/test_solver.py @@ -319,7 +319,7 @@ def test_solve_atai(cho_factor, wide, weighted, alpha): D = alpha * snp.ones((A.shape[1],)) ATAD = A.T @ (Wa * A) + alpha * snp.identity(A.shape[1]) b = ATAD @ x0 - slv = solver.ATADSolver(A, D, W=W, cho_factor=cho_factor) + slv = solver.MatrixATADSolver(A, D, W=W, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @@ -338,7 +338,7 @@ def test_solve_aati(cho_factor, wide, alpha): D = alpha * snp.ones((A.shape[0],)) AATD = A @ A.T + alpha * snp.identity(A.shape[0]) b = AATD @ x0 - slv = solver.ATADSolver(A.T, D) + slv = solver.MatrixATADSolver(A.T, D) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 @@ -365,7 +365,7 @@ def test_solve_atad(cho_factor, wide, vector): D = snp.abs(D) # only required for Cholesky, but improved accuracy for LU ATAD = A.T @ A + snp.diag(D) b = ATAD @ x0 - slv = solver.ATADSolver(A, D, cho_factor=cho_factor) + slv = solver.MatrixATADSolver(A, D, cho_factor=cho_factor) x1 = slv.solve(b) assert metric.rel_res(x0, x1) < 5e-5 assert slv.accuracy(x1, b) < 5e-5