From cf7b8761f1188ff3270350c0a47f37d1e2000a72 Mon Sep 17 00:00:00 2001
From: Felix Zimmermann
Date: Sun, 20 Oct 2024 03:15:02 +0200
Subject: [PATCH 1/8] pdhg

---
 src/mrpro/algorithms/optimizers/__init__.py |   1 +
 src/mrpro/algorithms/optimizers/pdhg.py     | 174 ++++++++++++++++++++
 2 files changed, 175 insertions(+)
 create mode 100644 src/mrpro/algorithms/optimizers/pdhg.py

diff --git a/src/mrpro/algorithms/optimizers/__init__.py b/src/mrpro/algorithms/optimizers/__init__.py
index 5076e454..8eb97ad9 100644
--- a/src/mrpro/algorithms/optimizers/__init__.py
+++ b/src/mrpro/algorithms/optimizers/__init__.py
@@ -2,3 +2,4 @@
 from mrpro.algorithms.optimizers.adam import adam
 from mrpro.algorithms.optimizers.cg import cg
 from mrpro.algorithms.optimizers.lbfgs import lbfgs
+from mrpro.algorithms.optimizers.pdhg import pdhg
diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py
new file mode 100644
index 00000000..c43687a2
--- /dev/null
+++ b/src/mrpro/algorithms/optimizers/pdhg.py
@@ -0,0 +1,174 @@
+"""Primal-Dual Hybrid Gradient Algorithm (PDHG)."""
+
+from __future__ import annotations
+
+import warnings
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+import torch
+
+from mrpro.algorithms.optimizers import OptimizerStatus
+from mrpro.operators import (
+    IdentityOp,
+    LinearOperator,
+    LinearOperatorMatrix,
+    ProximableFunctional,
+    ProximableFunctionalSeparableSum,
+)
+from mrpro.operators.functionals import ZeroFunctional
+
+
+@dataclass
+class PDHGStatus(OptimizerStatus):
+    """Status of the PDHG algorithm."""
+
+    objective: Callable[[*tuple[torch.Tensor, ...]], torch.Tensor]
+    dual_stepsize: float | torch.Tensor
+    primal_stepsize: float | torch.Tensor
+    relaxation: float | torch.Tensor
+    duals: Sequence[torch.Tensor]
+    relaxed: Sequence[torch.Tensor]
+
+
+def pdhg(
+    initial_values: Sequence[torch.Tensor],
+    f: ProximableFunctionalSeparableSum | ProximableFunctional | None = None,
+    g: ProximableFunctionalSeparableSum | ProximableFunctional | None = None,
+    operator: LinearOperator | LinearOperatorMatrix | None = None,
+    n_iterations: int = 10,
+    primal_stepsize: float | None = None,
+    dual_stepsize: float | None = None,
+    relaxation: float = 1.0,
+    initial_relaxed: Sequence[torch.Tensor] | None = None,
+    initial_duals: Sequence[torch.Tensor] | None = None,
+    callback: Callable[[PDHGStatus], None] | None = None,
+) -> tuple[torch.Tensor, ...]:
+    r"""Primal-Dual Hybrid Gradient Algorithm (PDHG).
+
+    Solves the minimization problem
+    :math:`\min_x g(x) + f(A x)`
+    with linear operator A and proximable functionals f and g.
+
+    The operator is supplied as a matrix (tuple of tuples) of linear operators;
+    f and g are supplied as tuples of proximable functionals interpreted as separable sums.
+
+    Thus, the problem solved is
+    :math:`\min_x \sum_j g_j(x_j) + \sum_i f_i\left(\sum_j A_{ij} x_j\right)`.
+
+    If neither the primal nor the dual step size is supplied, both are chosen as :math:`1/\|A\|_{op}`.
+    If only one is supplied, the other is chosen such that primal_stepsize * dual_stepsize = :math:`1/\|A\|_{op}^2`.
+
+    For a warm start, the relaxed solution x_relaxed and the dual variables can be supplied.
+    These might be obtained from the Status object of a previous run.
+ + Parameters + ---------- + initial_values + initial guess + f + tuple of proximable functionals interpreted as a separable sum + g + tuple of proximable functionals interpreted as a separable sum + operator + matrix of linear operators + n_iterations + number of iterations + dual_stepsize + dual step size + primal_stepsize + primal step size + relaxation + relaxation parameter, 1.0 is no relaxation + initial_relaxed + relaxed primals, used for warm start + initial_duals + dual variables, used for warm start + callback + callback function called after each iteration + """ + if f is None and g is None: + warnings.warn( + 'Both f and g are None. The objective is constant. Returning x0 as a possible solution', stacklevel=2 + ) + return tuple(initial_values) + + if operator is None: + rows = len(f) if f is not None else 1 + cols = len(g) if g is not None else 1 + if rows != cols: + raise ValueError('If operator is None, the number of elements in f and g should be the same') + operator_matrix = LinearOperatorMatrix.from_diagonal(*((IdentityOp(),) * rows)) + else: + if isinstance(operator, LinearOperator): + operator_matrix = LinearOperatorMatrix.from_diagonal(operator) + else: + operator_matrix = operator + rows, cols = operator_matrix.shape + if f is not None and len(f) != rows: + raise ValueError('Number of rows in operator does not match number of functionals in f') + if g is not None and len(g) != cols: + raise ValueError('Number of columns in operator does not match number of functionals in g') + + if f is None: + f_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * rows) + elif isinstance(f, ProximableFunctional): + f_sum = ProximableFunctionalSeparableSum(f) + else: + f_sum = f + + if g is None: + g_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * cols) + elif isinstance(g, ProximableFunctional): + g_sum = ProximableFunctionalSeparableSum(g) + else: + g_sum = g + + if primal_stepsize is None or dual_stepsize is None: + # choose primal and dual step size such that their product is 1/|operator|**2 + # to ensure convergence + operator_norm = operator_matrix.operator_norm(initial_values) + if primal_stepsize is None and dual_stepsize is None: + primal_stepsize_ = dual_stepsize = 1.0 / operator_norm + elif primal_stepsize is None: + primal_stepsize_ = 1 / (operator_norm * dual_stepsize) + elif dual_stepsize is None: + dual_stepsize_ = 1 / (operator_norm * primal_stepsize) + else: + primal_stepsize_ = primal_stepsize + dual_stepsize_ = dual_stepsize + + primals_relaxed = initial_values if initial_relaxed is None else initial_relaxed + duals = (0 * operator_matrix)(initial_values) if initial_duals is None else initial_duals + + if len(duals) != rows: + raise ValueError('if dual y is supplied, it should be a tuple of same length as the tuple of g') + + primals = initial_values + for i in range(n_iterations): + duals = tuple( + dual + dual_stepsize_ * step for dual, step in zip(duals, operator_matrix(primals_relaxed), strict=False) + ) + duals = f_sum.prox_convex_conj(*duals, sigma=dual_stepsize_) + + primals_new = tuple( + primal - primal_stepsize_ * step for primal, step in zip(primals, operator_matrix.H(duals), strict=False) + ) + primals_new = g_sum.prox(*primals_new, sigma=primal_stepsize_) + primals_relaxed = [ + torch.lerp(primal, primal_new, relaxation) for primal, primal_new in zip(primals, primals_new, strict=False) + ] + primals = primals_new + if callback is not None: + status = PDHGStatus( + iteration_number=i, + dual_stepsize=dual_stepsize_, + 
primal_stepsize=primal_stepsize_, + relaxation=relaxation, + duals=duals, + solution=tuple(primals), + relaxed=primals_relaxed, + objective=lambda *x: f_sum.forward(*operator_matrix(*x))[0] + g_sum.forward(*x)[0], + ) + callback(status) + return tuple(primals) From 235f76c039d4caea6d1944a0bbd0d53a0fe1f0da Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 12 Nov 2024 18:27:24 +0100 Subject: [PATCH 2/8] fix import --- src/mrpro/algorithms/optimizers/pdhg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index c43687a2..f234927b 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -8,7 +8,7 @@ import torch -from mrpro.algorithms.optimizers import OptimizerStatus +from mrpro.algorithms.optimizers.OptimizerStatus import OptimizerStatus from mrpro.operators import ( IdentityOp, LinearOperator, From eb177f42b9d96dc4a39f456efbcd20d2faf5fe10 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Tue, 12 Nov 2024 18:31:40 +0100 Subject: [PATCH 3/8] comments --- src/mrpro/algorithms/optimizers/pdhg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index f234927b..ccbb8c11 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -94,6 +94,7 @@ def pdhg( return tuple(initial_values) if operator is None: + # Use identity operator if no operator is supplied rows = len(f) if f is not None else 1 cols = len(g) if g is not None else 1 if rows != cols: @@ -101,6 +102,7 @@ def pdhg( operator_matrix = LinearOperatorMatrix.from_diagonal(*((IdentityOp(),) * rows)) else: if isinstance(operator, LinearOperator): + # We allways use a matrix of operators for homogeneous handling operator_matrix = LinearOperatorMatrix.from_diagonal(operator) else: operator_matrix = operator @@ -111,6 +113,7 @@ def pdhg( raise ValueError('Number of columns in operator does not match number of functionals in g') if f is None: + # We always use a separable sum for homogeneous handling, even if it is just a ZeroFunctional f_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * rows) elif isinstance(f, ProximableFunctional): f_sum = ProximableFunctionalSeparableSum(f) @@ -118,6 +121,7 @@ def pdhg( f_sum = f if g is None: + # We always use a separable sum for homogeneous handling, even if it is just a ZeroFunctional g_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * cols) elif isinstance(g, ProximableFunctional): g_sum = ProximableFunctionalSeparableSum(g) From 23d7f804c1f48bdc718e7e81296d572127d3f567 Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Fri, 15 Nov 2024 00:47:11 +0100 Subject: [PATCH 4/8] fix --- src/mrpro/algorithms/optimizers/pdhg.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index ccbb8c11..0df4a8c3 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -95,14 +95,14 @@ def pdhg( if operator is None: # Use identity operator if no operator is supplied - rows = len(f) if f is not None else 1 - cols = len(g) if g is not None else 1 + rows = len(f) if isinstance(f, ProximableFunctionalSeparableSum) else 1 + cols = len(g) if isinstance(g, ProximableFunctionalSeparableSum) else 1 if rows != cols: raise ValueError('If operator is None, the number of 
elements in f and g should be the same') operator_matrix = LinearOperatorMatrix.from_diagonal(*((IdentityOp(),) * rows)) else: if isinstance(operator, LinearOperator): - # We allways use a matrix of operators for homogeneous handling + # We always use a matrix of operators for homogeneous handling operator_matrix = LinearOperatorMatrix.from_diagonal(operator) else: operator_matrix = operator @@ -131,9 +131,9 @@ def pdhg( if primal_stepsize is None or dual_stepsize is None: # choose primal and dual step size such that their product is 1/|operator|**2 # to ensure convergence - operator_norm = operator_matrix.operator_norm(initial_values) + operator_norm = operator_matrix.operator_norm(*initial_values) if primal_stepsize is None and dual_stepsize is None: - primal_stepsize_ = dual_stepsize = 1.0 / operator_norm + primal_stepsize_ = dual_stepsize_ = 1.0 / operator_norm elif primal_stepsize is None: primal_stepsize_ = 1 / (operator_norm * dual_stepsize) elif dual_stepsize is None: @@ -151,12 +151,12 @@ def pdhg( primals = initial_values for i in range(n_iterations): duals = tuple( - dual + dual_stepsize_ * step for dual, step in zip(duals, operator_matrix(primals_relaxed), strict=False) + dual + dual_stepsize_ * step for dual, step in zip(duals, operator_matrix(*primals_relaxed), strict=False) ) duals = f_sum.prox_convex_conj(*duals, sigma=dual_stepsize_) primals_new = tuple( - primal - primal_stepsize_ * step for primal, step in zip(primals, operator_matrix.H(duals), strict=False) + primal - primal_stepsize_ * step for primal, step in zip(primals, operator_matrix.H(*duals), strict=False) ) primals_new = g_sum.prox(*primals_new, sigma=primal_stepsize_) primals_relaxed = [ From f4d4a9b23a8f203aeb3e5516b226a002421dcb32 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Mon, 18 Nov 2024 14:33:03 +0100 Subject: [PATCH 5/8] fix --- src/mrpro/algorithms/optimizers/pdhg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index 0df4a8c3..c285516a 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -143,7 +143,7 @@ def pdhg( dual_stepsize_ = dual_stepsize primals_relaxed = initial_values if initial_relaxed is None else initial_relaxed - duals = (0 * operator_matrix)(initial_values) if initial_duals is None else initial_duals + duals = (0 * operator_matrix)(*initial_values) if initial_duals is None else initial_duals if len(duals) != rows: raise ValueError('if dual y is supplied, it should be a tuple of same length as the tuple of g') From ca10bbd4197434b03c89017d45c1dc3f2fae8659 Mon Sep 17 00:00:00 2001 From: koflera Date: Mon, 18 Nov 2024 14:34:17 +0100 Subject: [PATCH 6/8] tests --- tests/algorithms/test_pdhg.py | 142 ++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 tests/algorithms/test_pdhg.py diff --git a/tests/algorithms/test_pdhg.py b/tests/algorithms/test_pdhg.py new file mode 100644 index 00000000..7bb25a08 --- /dev/null +++ b/tests/algorithms/test_pdhg.py @@ -0,0 +1,142 @@ +"""Tests for PDHG.""" + +import torch +from mrpro.algorithms.optimizers import pdhg +from mrpro.operators import FastFourierOp, IdentityOp, LinearOperatorMatrix, ProximableFunctionalSeparableSum, WaveletOp +from mrpro.operators.functionals import L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional +from tests import RandomGenerator + + +def test_l2_l1_identification1(): + """Set up the problem min_x 1/2*||x - y||_2^2 + lambda * 
||x||_1, + which has a closed form solution given by the soft-thresholding operator. + + Here, for f(K(x)) + g(x), we used the identification + f(x) = 1/2 * || p - y ||_2^2 + g(x) = lambda * ||x||_1 + K = Id + """ + random_generator = RandomGenerator(seed=0) + + data_shape = (160, 160) + data = random_generator.float32_tensor(size=data_shape) + + regularization_parameter = 0.1 + + l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False) + l1 = regularization_parameter * L1Norm(divide_by_n=False) + + f = l2 + g = l1 + operator = IdentityOp() + + initial_values = (random_generator.float32_tensor(size=data_shape),) + expected = torch.nn.functional.softshrink(data, regularization_parameter) + + n_iterations = 64 + (pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations) + torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4) + + +def test_l2_l1_identification2(): + """Set up the problem min_x 1/2*||x - y||_2^2 + lambda * ||x||_1, + which has a closed form solution given by the soft-thresholding operator. + + Here, for f(K(x)) + g(x), we used the identification + f(p,q) = f1(p) + f2(q) = 1/2 * || p - y ||_2^2 + lambda * ||q||_1 + g(x) = 0 for all x, + K = [Id, Id]^T + """ + random_generator = RandomGenerator(seed=0) + + data_shape = (32, 64, 64) + data = random_generator.float32_tensor(size=data_shape) + + regularization_parameter = 0.5 + + l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False) + l1 = regularization_parameter * L1Norm(divide_by_n=False) + + f = ProximableFunctionalSeparableSum(l2, l1) + g = ZeroFunctional() + operator = LinearOperatorMatrix(((IdentityOp(),), (IdentityOp(),))) + + initial_values = (random_generator.float32_tensor(size=data_shape),) + expected = torch.nn.functional.softshrink(data, regularization_parameter) + + n_iterations = 64 + (pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations) + torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4) + + +def test_fourier_l2_l1_(): + """Set up the problem min_x 1/2*|| Fx - y||_2^2 + lambda * ||x||_1, + where F is the full FFT sampled on a Cartesian grid. Thus, again, the + problem has a closed-form solution given by soft-thresholding. + """ + random_generator = RandomGenerator(seed=0) + + image_shape = (32, 48, 48) + image = random_generator.complex64_tensor(size=image_shape) + + fourier_op = FastFourierOp(dim=(-3, -2, -1)) + + (data,) = fourier_op(image) + + regularization_parameter = 0.5 + + l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False) + l1 = regularization_parameter * L1NormViewAsReal(divide_by_n=False) + + f = ProximableFunctionalSeparableSum(l2, l1) + g = ZeroFunctional() + operator = LinearOperatorMatrix(((fourier_op,), (IdentityOp(),))) + + initial_values = (random_generator.complex64_tensor(size=image_shape),) + expected = torch.view_as_complex( + torch.nn.functional.softshrink(torch.view_as_real(fourier_op.H(data)[0]), regularization_parameter) + ) + + n_iterations = 128 + (pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations) + torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4) + + +def test_fourier_l2_wavelet_l1_(): + """Set up the problem min_x 1/2*|| Fx - y||_2^2 + lambda * || W x||_1, + where F is the full FFT sampled on a Cartesian grid and W a wavelet transform. 
+ Because both F and W are invertible, the problem has a closed-form solution + obtainable by soft-thresholding. + """ + random_generator = RandomGenerator(seed=0) + + image_shape = (6, 32, 32) + image = random_generator.complex64_tensor(size=image_shape) + + dim = (-3, -2, -1) + fourier_op = FastFourierOp(dim=dim) + wavelet_op = WaveletOp(domain_shape=image_shape, dim=dim) + + (data,) = fourier_op(image) + + regularization_parameter = 0.5 + + l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False) + l1 = regularization_parameter * L1NormViewAsReal(divide_by_n=False) + + f = ProximableFunctionalSeparableSum(l2, l1) + g = ZeroFunctional() + operator = LinearOperatorMatrix(((fourier_op,), (wavelet_op,))) + + initial_values = (random_generator.complex64_tensor(size=image_shape),) + expected = wavelet_op.H( + torch.view_as_complex( + torch.nn.functional.softshrink( + torch.view_as_real(wavelet_op(fourier_op.H(data)[0])[0]), regularization_parameter + ) + ) + )[0] + + n_iterations = 128 + (pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations) + torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4) From 86b684ae9ed7e715770727f11bab190fb81dce9f Mon Sep 17 00:00:00 2001 From: Andreas Kofler Date: Mon, 18 Nov 2024 14:43:18 +0100 Subject: [PATCH 7/8] fix --- src/mrpro/algorithms/optimizers/pdhg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index c285516a..186f0aea 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -107,16 +107,14 @@ def pdhg( else: operator_matrix = operator rows, cols = operator_matrix.shape - if f is not None and len(f) != rows: - raise ValueError('Number of rows in operator does not match number of functionals in f') - if g is not None and len(g) != cols: - raise ValueError('Number of columns in operator does not match number of functionals in g') if f is None: # We always use a separable sum for homogeneous handling, even if it is just a ZeroFunctional f_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * rows) elif isinstance(f, ProximableFunctional): f_sum = ProximableFunctionalSeparableSum(f) + if len(f) != rows: + raise ValueError('Number of rows in operator does not match number of functionals in f') else: f_sum = f @@ -125,6 +123,8 @@ def pdhg( g_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * cols) elif isinstance(g, ProximableFunctional): g_sum = ProximableFunctionalSeparableSum(g) + if len(g) != cols: + raise ValueError('Number of columns in operator does not match number of functionals in g') else: g_sum = g From 930bc9283ca55e2262e4c48058a4cff1dbf03717 Mon Sep 17 00:00:00 2001 From: Andreas Kofler Date: Mon, 18 Nov 2024 14:49:48 +0100 Subject: [PATCH 8/8] fix the fix --- src/mrpro/algorithms/optimizers/pdhg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mrpro/algorithms/optimizers/pdhg.py b/src/mrpro/algorithms/optimizers/pdhg.py index 186f0aea..573129f2 100644 --- a/src/mrpro/algorithms/optimizers/pdhg.py +++ b/src/mrpro/algorithms/optimizers/pdhg.py @@ -113,8 +113,8 @@ def pdhg( f_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * rows) elif isinstance(f, ProximableFunctional): f_sum = ProximableFunctionalSeparableSum(f) - if len(f) != rows: - raise ValueError('Number of rows in operator does not match number of functionals in f') + elif len(f) != rows: + raise 
ValueError('Number of rows in operator does not match number of functionals in f') else: f_sum = f @@ -123,8 +123,8 @@ def pdhg( g_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * cols) elif isinstance(g, ProximableFunctional): g_sum = ProximableFunctionalSeparableSum(g) - if len(g) != cols: - raise ValueError('Number of columns in operator does not match number of functionals in g') + elif len(g) != cols: + raise ValueError('Number of columns in operator does not match number of functionals in g') else: g_sum = g
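
With dual step size :math:`\sigma` (dual_stepsize_), primal step size :math:`\tau` (primal_stepsize_), relaxation :math:`\theta` (relaxation) and the adjoint operator :math:`A^H`, the main loop of pdhg above performs the following updates per iteration:

    y^{k+1}       = \mathrm{prox}_{\sigma f^*}\left( y^k + \sigma A \bar{x}^k \right)
    x^{k+1}       = \mathrm{prox}_{\tau g}\left( x^k - \tau A^H y^{k+1} \right)
    \bar{x}^{k+1} = x^k + \theta \left( x^{k+1} - x^k \right)

The last line is the torch.lerp call; with the default :math:`\theta = 1`, the relaxed iterate equals the new primal iterate.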
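
A minimal usage sketch for the soft-thresholding problem from the first test above; the data shape, the random data, the regularization strength 0.1 and the zero initial guess are illustrative choices, and an mrpro installation including this patch series is assumed:

    import torch

    from mrpro.algorithms.optimizers import pdhg
    from mrpro.operators import IdentityOp
    from mrpro.operators.functionals import L1Norm, L2NormSquared

    y = torch.randn(160, 160)                             # observed data (illustrative)
    f = 0.5 * L2NormSquared(target=y, divide_by_n=False)  # data fidelity f, applied to A x
    g = 0.1 * L1Norm(divide_by_n=False)                   # l1 regularizer g, applied to x

    # With A = IdentityOp(), this solves min_x 0.5*||x - y||_2^2 + 0.1*||x||_1,
    # whose closed-form solution is torch.nn.functional.softshrink(y, 0.1).
    (x,) = pdhg(f=f, g=g, operator=IdentityOp(), initial_values=(torch.zeros_like(y),), n_iterations=64)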
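
The matrix-of-operators form used in the Fourier tests follows the same pattern; a sketch with illustrative shapes and regularization strength, again starting from a zero image:

    import torch

    from mrpro.algorithms.optimizers import pdhg
    from mrpro.operators import (
        FastFourierOp,
        IdentityOp,
        LinearOperatorMatrix,
        ProximableFunctionalSeparableSum,
    )
    from mrpro.operators.functionals import L1NormViewAsReal, L2NormSquared, ZeroFunctional

    image = torch.randn(32, 48, 48, dtype=torch.complex64)  # ground-truth image (illustrative)
    fourier_op = FastFourierOp(dim=(-3, -2, -1))
    (kdata,) = fourier_op(image)                             # simulated k-space data

    # min_x 0.5*||F x - kdata||_2^2 + 0.5*||x||_1, written as f(K x) + g(x) with
    # K = [F, Id]^T, f(p, q) = 0.5*||p - kdata||_2^2 + 0.5*||q||_1 and g = 0.
    operator = LinearOperatorMatrix(((fourier_op,), (IdentityOp(),)))
    f = ProximableFunctionalSeparableSum(
        0.5 * L2NormSquared(target=kdata, divide_by_n=False),
        0.5 * L1NormViewAsReal(divide_by_n=False),
    )
    g = ZeroFunctional()

    (x,) = pdhg(f=f, g=g, operator=operator, initial_values=(torch.zeros_like(image),), n_iterations=128)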