Add Primal Dual Hybrid Gradient algorithm #426

Draft
wants to merge 14 commits into main
3 changes: 2 additions & 1 deletion src/mrpro/algorithms/optimizers/__init__.py
@@ -2,4 +2,5 @@
from mrpro.algorithms.optimizers.adam import adam
from mrpro.algorithms.optimizers.cg import cg
from mrpro.algorithms.optimizers.lbfgs import lbfgs
__all__ = ["OptimizerStatus", "adam", "cg", "lbfgs"]
from mrpro.algorithms.optimizers.pdhg import pdhg
__all__ = ["OptimizerStatus", "adam", "cg", "lbfgs", "pdhg"]
178 changes: 178 additions & 0 deletions src/mrpro/algorithms/optimizers/pdhg.py
@@ -0,0 +1,178 @@
"""Primal-Dual Hybrid Gradient Algorithm (PDHG)."""

from __future__ import annotations

import warnings
from collections.abc import Callable, Sequence
from dataclasses import dataclass

import torch

from mrpro.algorithms.optimizers.OptimizerStatus import OptimizerStatus
from mrpro.operators import (
IdentityOp,
LinearOperator,
LinearOperatorMatrix,
ProximableFunctional,
ProximableFunctionalSeparableSum,
)
from mrpro.operators.functionals import ZeroFunctional


@dataclass
class PDHGStatus(OptimizerStatus):
"""Status of the PDHG algorithm."""

objective: Callable[[*tuple[torch.Tensor, ...]], torch.Tensor]
dual_stepsize: float | torch.Tensor
primal_stepsize: float | torch.Tensor
relaxation: float | torch.Tensor
duals: Sequence[torch.Tensor]
relaxed: Sequence[torch.Tensor]


def pdhg(
initial_values: Sequence[torch.Tensor],
f: ProximableFunctionalSeparableSum | ProximableFunctional | None = None,
g: ProximableFunctionalSeparableSum | ProximableFunctional | None = None,
operator: LinearOperator | LinearOperatorMatrix | None = None,
n_iterations: int = 10,
primal_stepsize: float | None = None,
dual_stepsize: float | None = None,
relaxation: float = 1.0,
initial_relaxed: Sequence[torch.Tensor] | None = None,
initial_duals: Sequence[torch.Tensor] | None = None,
callback: Callable[[PDHGStatus], None] | None = None,
) -> tuple[torch.Tensor, ...]:
r"""Primal-Dual Hybrid Gradient Algorithm (PDHG).

Solves the minimization problem
:math:`\min_x g(x) + f(A x)`
with linear operator A and proximable functionals f and g.

    The operator is supplied as a matrix (tuple of tuples) of linear operators;
    f and g are supplied as separable sums of proximable functionals.

    Thus, the problem solved is
    :math:`\min_x \sum_{i,j} g_j(x_j) + f_i(A_{ij} x_j)`.

    If neither the primal nor the dual step size is supplied, both are chosen as :math:`1/||A||_{op}`.
    If only one is supplied, the other is chosen such that primal_stepsize * dual_stepsize = :math:`1/||A||_{op}^2`.

For a warm start, the relaxed solution x_relaxed and dual variables can be supplied.
These might be obtained from the Status object of a previous run.

Parameters
----------
initial_values
initial guess
    f
        proximable functional or separable sum of proximable functionals, applied to the operator output :math:`Ax`
    g
        proximable functional or separable sum of proximable functionals, applied directly to :math:`x`
operator
matrix of linear operators
n_iterations
number of iterations
dual_stepsize
dual step size
primal_stepsize
primal step size
relaxation
relaxation parameter, 1.0 is no relaxation
initial_relaxed
relaxed primals, used for warm start
initial_duals
dual variables, used for warm start
callback
callback function called after each iteration
"""
if f is None and g is None:
warnings.warn(
            'Both f and g are None. The objective is constant. Returning the initial values as a possible solution', stacklevel=2
)
return tuple(initial_values)

if operator is None:
# Use identity operator if no operator is supplied
rows = len(f) if isinstance(f, ProximableFunctionalSeparableSum) else 1
cols = len(g) if isinstance(g, ProximableFunctionalSeparableSum) else 1
if rows != cols:
raise ValueError('If operator is None, the number of elements in f and g should be the same')
operator_matrix = LinearOperatorMatrix.from_diagonal(*((IdentityOp(),) * rows))
else:
if isinstance(operator, LinearOperator):
# We always use a matrix of operators for homogeneous handling
operator_matrix = LinearOperatorMatrix.from_diagonal(operator)
else:
operator_matrix = operator
rows, cols = operator_matrix.shape

if f is None:
# We always use a separable sum for homogeneous handling, even if it is just a ZeroFunctional
f_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * rows)
elif isinstance(f, ProximableFunctional):
f_sum = ProximableFunctionalSeparableSum(f)
elif len(f) != rows:
raise ValueError('Number of rows in operator does not match number of functionals in f')
else:
f_sum = f

if g is None:
# We always use a separable sum for homogeneous handling, even if it is just a ZeroFunctional
g_sum = ProximableFunctionalSeparableSum(*(ZeroFunctional(),) * cols)
elif isinstance(g, ProximableFunctional):
g_sum = ProximableFunctionalSeparableSum(g)
elif len(g) != cols:
raise ValueError('Number of columns in operator does not match number of functionals in g')
else:
g_sum = g

    if primal_stepsize is None or dual_stepsize is None:
        # choose the primal and dual step sizes such that their product is 1/||operator||_op^2
        # to ensure convergence
        operator_norm = operator_matrix.operator_norm(*initial_values)
        if primal_stepsize is None and dual_stepsize is None:
            primal_stepsize_ = dual_stepsize_ = 1.0 / operator_norm
        elif primal_stepsize is None:
            dual_stepsize_ = dual_stepsize
            primal_stepsize_ = 1 / (operator_norm**2 * dual_stepsize_)
        else:
            primal_stepsize_ = primal_stepsize
            dual_stepsize_ = 1 / (operator_norm**2 * primal_stepsize_)
    else:
        primal_stepsize_ = primal_stepsize
        dual_stepsize_ = dual_stepsize

primals_relaxed = initial_values if initial_relaxed is None else initial_relaxed
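    # unless warm-started, the duals start as zero tensors with the shapes of operator_matrix(*initial_values)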
duals = (0 * operator_matrix)(*initial_values) if initial_duals is None else initial_duals

if len(duals) != rows:
        raise ValueError('if initial_duals is supplied, it should be a tuple with one entry per functional in f (one per row of the operator)')

primals = initial_values
for i in range(n_iterations):
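        # one PDHG iteration: dual ascent step plus prox of the convex conjugate of f,
        # primal descent step plus prox of g, then over-relaxation of the primal variables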
duals = tuple(
dual + dual_stepsize_ * step for dual, step in zip(duals, operator_matrix(*primals_relaxed), strict=False)
)
duals = f_sum.prox_convex_conj(*duals, sigma=dual_stepsize_)

primals_new = tuple(
primal - primal_stepsize_ * step for primal, step in zip(primals, operator_matrix.H(*duals), strict=False)
)
primals_new = g_sum.prox(*primals_new, sigma=primal_stepsize_)
primals_relaxed = [
torch.lerp(primal, primal_new, relaxation) for primal, primal_new in zip(primals, primals_new, strict=False)
]
primals = primals_new
if callback is not None:
status = PDHGStatus(
iteration_number=i,
dual_stepsize=dual_stepsize_,
primal_stepsize=primal_stepsize_,
relaxation=relaxation,
duals=duals,
solution=tuple(primals),
relaxed=primals_relaxed,
objective=lambda *x: f_sum.forward(*operator_matrix(*x))[0] + g_sum.forward(*x)[0],
)
callback(status)
return tuple(primals)
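For context, a minimal usage sketch of the new solver (not part of this diff), mirroring the first test below. The tensor shape, the zero initialization, and the regularization strength 0.1 are illustrative assumptions; the functional and operator APIs are the ones used in the tests.

import torch
from mrpro.algorithms.optimizers import pdhg
from mrpro.operators import IdentityOp
from mrpro.operators.functionals import L1Norm, L2NormSquared

# min_x 1/2 * ||x - y||_2^2 + 0.1 * ||x||_1; the exact minimizer is softshrink(y, 0.1)
y = torch.randn(160, 160)
f = 0.5 * L2NormSquared(target=y, divide_by_n=False)  # data fidelity, applied to the operator output A x
g = 0.1 * L1Norm(divide_by_n=False)  # l1 regularization, applied directly to x
(solution,) = pdhg(f=f, g=g, operator=IdentityOp(), initial_values=(torch.zeros_like(y),), n_iterations=64)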
142 changes: 142 additions & 0 deletions tests/algorithms/test_pdhg.py
@@ -0,0 +1,142 @@
"""Tests for PDHG."""

import torch
from mrpro.algorithms.optimizers import pdhg
from mrpro.operators import FastFourierOp, IdentityOp, LinearOperatorMatrix, ProximableFunctionalSeparableSum, WaveletOp
from mrpro.operators.functionals import L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional
from tests import RandomGenerator


def test_l2_l1_identification1():
"""Set up the problem min_x 1/2*||x - y||_2^2 + lambda * ||x||_1,
which has a closed form solution given by the soft-thresholding operator.

Here, for f(K(x)) + g(x), we used the identification
    f(x) = 1/2 * || x - y ||_2^2
g(x) = lambda * ||x||_1
K = Id
"""
random_generator = RandomGenerator(seed=0)

data_shape = (160, 160)
data = random_generator.float32_tensor(size=data_shape)

regularization_parameter = 0.1

l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False)
l1 = regularization_parameter * L1Norm(divide_by_n=False)

f = l2
g = l1
operator = IdentityOp()

initial_values = (random_generator.float32_tensor(size=data_shape),)
expected = torch.nn.functional.softshrink(data, regularization_parameter)

n_iterations = 64
(pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations)
torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4)


def test_l2_l1_identification2():
"""Set up the problem min_x 1/2*||x - y||_2^2 + lambda * ||x||_1,
which has a closed form solution given by the soft-thresholding operator.

Here, for f(K(x)) + g(x), we used the identification
f(p,q) = f1(p) + f2(q) = 1/2 * || p - y ||_2^2 + lambda * ||q||_1
g(x) = 0 for all x,
K = [Id, Id]^T
"""
random_generator = RandomGenerator(seed=0)

data_shape = (32, 64, 64)
data = random_generator.float32_tensor(size=data_shape)

regularization_parameter = 0.5

l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False)
l1 = regularization_parameter * L1Norm(divide_by_n=False)

f = ProximableFunctionalSeparableSum(l2, l1)
g = ZeroFunctional()
operator = LinearOperatorMatrix(((IdentityOp(),), (IdentityOp(),)))

initial_values = (random_generator.float32_tensor(size=data_shape),)
expected = torch.nn.functional.softshrink(data, regularization_parameter)

n_iterations = 64
(pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations)
torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4)

Check failure on line 69 in tests/algorithms/test_pdhg.py (GitHub Actions / Run Tests and Coverage Report: mrpro_py310, mrpro_py311, mrpro_py312):
test_l2_l1_identification2 AssertionError: Tensor-likes are not close! Mismatched elements: 30184 / 131072 (23.0%). Greatest absolute difference: 0.001042758347466588 at index (29, 40, 21) (up to 0.0005 allowed). Greatest relative difference: inf at index (0, 0, 2) (up to 0.0005 allowed).


def test_fourier_l2_l1_():
"""Set up the problem min_x 1/2*|| Fx - y||_2^2 + lambda * ||x||_1,
where F is the full FFT sampled on a Cartesian grid. Thus, again, the
problem has a closed-form solution given by soft-thresholding.
"""
random_generator = RandomGenerator(seed=0)

image_shape = (32, 48, 48)
image = random_generator.complex64_tensor(size=image_shape)

fourier_op = FastFourierOp(dim=(-3, -2, -1))

(data,) = fourier_op(image)

regularization_parameter = 0.5

l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False)
l1 = regularization_parameter * L1NormViewAsReal(divide_by_n=False)

f = ProximableFunctionalSeparableSum(l2, l1)
g = ZeroFunctional()
operator = LinearOperatorMatrix(((fourier_op,), (IdentityOp(),)))

initial_values = (random_generator.complex64_tensor(size=image_shape),)
expected = torch.view_as_complex(
torch.nn.functional.softshrink(torch.view_as_real(fourier_op.H(data)[0]), regularization_parameter)
)

n_iterations = 128
(pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations)
torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4)


def test_fourier_l2_wavelet_l1_():
"""Set up the problem min_x 1/2*|| Fx - y||_2^2 + lambda * || W x||_1,
where F is the full FFT sampled on a Cartesian grid and W a wavelet transform.
Because both F and W are invertible, the problem has a closed-form solution
obtainable by soft-thresholding.
"""
random_generator = RandomGenerator(seed=0)

image_shape = (6, 32, 32)
image = random_generator.complex64_tensor(size=image_shape)

dim = (-3, -2, -1)
fourier_op = FastFourierOp(dim=dim)
wavelet_op = WaveletOp(domain_shape=image_shape, dim=dim)

(data,) = fourier_op(image)

regularization_parameter = 0.5

l2 = 0.5 * L2NormSquared(target=data, divide_by_n=False)
l1 = regularization_parameter * L1NormViewAsReal(divide_by_n=False)

f = ProximableFunctionalSeparableSum(l2, l1)
g = ZeroFunctional()
operator = LinearOperatorMatrix(((fourier_op,), (wavelet_op,)))

initial_values = (random_generator.complex64_tensor(size=image_shape),)
expected = wavelet_op.H(
torch.view_as_complex(
torch.nn.functional.softshrink(
torch.view_as_real(wavelet_op(fourier_op.H(data)[0])[0]), regularization_parameter
)
)
)[0]

n_iterations = 128
(pdhg_solution,) = pdhg(f=f, g=g, operator=operator, initial_values=initial_values, n_iterations=n_iterations)
torch.testing.assert_close(pdhg_solution, expected, rtol=5e-4, atol=5e-4)
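
For reference, the closed-form solutions these tests compare against follow from the proximal map of the scaled l1 norm: with z the data (or the adjoint-transformed data, since the tests use a full Cartesian FFT and an invertible wavelet transform), the minimizer of the denoising subproblem is given elementwise by soft-thresholding,

.. math::
    \hat{x} = \arg\min_x \tfrac{1}{2}\|x - z\|_2^2 + \lambda\|x\|_1,
    \qquad
    \hat{x}_i = \operatorname{sign}(z_i)\,\max(|z_i| - \lambda, 0) = \mathrm{softshrink}(z, \lambda)_i .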