Commit: fix merge
fzimmermann89 committed Nov 12, 2024
1 parent fcf4442 commit e23054b
Showing 3 changed files with 6 additions and 22 deletions.
src/mrpro/algorithms/optimizers/__init__.py (2 changes: 1 addition, 1 deletion)
@@ -4,4 +4,4 @@
 from mrpro.algorithms.optimizers.lbfgs import lbfgs
 from mrpro.algorithms.optimizers.pgd import pgd
 
-__all__ = ["OptimizerStatus", "adam", "cg", "lbfgs", "pdg"]
+__all__ = ["OptimizerStatus", "adam", "cg", "lbfgs", "pdg", "pgd"]
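
For context: __all__ determines which names a star-import of this package exposes, and any listed name that the module does not actually define makes that star-import raise AttributeError. A minimal sketch of that behavior (illustrative module, not mrpro code):

# illustrative module "opt.py" (hypothetical)
def pgd():
    """A stand-in optimizer."""

__all__ = ["pdg"]  # misspelled: the module defines "pgd", not "pdg"

# in another file:
#   from opt import *
# raises: AttributeError: module 'opt' has no attribute 'pdg'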
src/mrpro/algorithms/optimizers/pgd.py (22 changes: 3 additions, 19 deletions)
@@ -4,28 +4,12 @@
 
 import torch
 
-from mrpro.operators.Functional import Functional
-
-
-def grad_and_value(function, x, create_graph=False):
-    def inner(x):
-        if create_graph and x.requires_grad:
-            # shallow clone
-            xg = x.view_as(x)
-
-        xg = x.detach().requires_grad_(True)
-        (y,) = function(xg)
-
-        yg = y if isinstance(y, torch.Tensor) else y[0]
-        grad = torch.autograd.grad(yg, xg)[0]
-        return grad, y
-
-    return inner(x)
+from mrpro.operators.Functional import Functional, ProximableFunctional
 
 
 def pgd(
     f: Functional,
-    g: Functional,
+    g: ProximableFunctional,  # TODO: would it work with ProximableFunctionalSeparableSum?
     initial_value: torch.Tensor,
     stepsize: float = 1.0,
     reg_parameter: float = 0.01,
@@ -97,7 +81,7 @@ def pgd(
     for _ in range(max_iterations):
         while stepsize > 1e-30:
             # calculate the proximal gradient step
-            gradient, f_y = grad_and_value(f, y)
+            gradient, f_y = torch.func.grad_and_value(f, y)
             (x,) = g.prox(y - stepsize * gradient, reg_parameter * stepsize)
 
             if not backtracking:
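
The diff drops the hand-rolled grad_and_value helper in favor of torch.func.grad_and_value, and narrows g to a ProximableFunctional, i.e. a functional providing the prox used in the update. For reference, a minimal self-contained sketch of the proximal gradient step this function performs (pure PyTorch; the soft-thresholding stands in for g.prox, and all names and values are illustrative, not taken from mrpro). Note that torch.func.grad_and_value(f) returns a callable, which is then applied to the iterate:

import torch

b = torch.randn(8)  # illustrative measurement vector

def f(x: torch.Tensor) -> torch.Tensor:
    # smooth data term: 0.5 * ||x - b||^2
    return 0.5 * ((x - b) ** 2).sum()

y = torch.zeros(8)  # current iterate
stepsize, reg_parameter = 1.0, 0.01

# gradient and value of f at y in one call
gradient, f_y = torch.func.grad_and_value(f)(y)

# proximal step: the prox of reg_parameter * ||.||_1 with step `stepsize`
# is elementwise soft-thresholding at threshold reg_parameter * stepsize
v = y - stepsize * gradient
threshold = reg_parameter * stepsize
x = torch.sign(v) * torch.clamp(v.abs() - threshold, min=0.0)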
tests/algorithms/test_pgd.py (4 changes: 2 additions, 2 deletions)
@@ -1,10 +1,10 @@
 """Tests for the proximal gradient descent."""
 
 import torch
 from mrpro.algorithms.optimizers import pgd
 from mrpro.data.SpatialDimension import SpatialDimension
 from mrpro.operators import FastFourierOp
-from mrpro.operators.functionals import L1Norm
-from mrpro.operators.functionals import L2NormSquared
+from mrpro.operators.functionals import L1Norm, L2NormSquared
 from mrpro.phantoms import EllipsePhantom
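
A hedged sketch of how these imports might fit together in such a test. The keyword names f, g, initial_value, stepsize, and reg_parameter come from the diff above; the constructor signatures of FastFourierOp, L2NormSquared, and the @ composition are assumptions about the mrpro API, as is the return handling of pgd:

import torch
from mrpro.algorithms.optimizers import pgd
from mrpro.operators import FastFourierOp
from mrpro.operators.functionals import L1Norm, L2NormSquared

# hypothetical setup: exact signatures and shapes are assumptions,
# not taken from this commit
fourier_op = FastFourierOp(dim=(-2, -1))
image = torch.randn(1, 64, 64, dtype=torch.complex64)
(kspace,) = fourier_op(image)  # mrpro operators return tuples

f = L2NormSquared(target=kspace) @ fourier_op  # smooth data-consistency term
g = L1Norm()                                   # proximable sparsity regularizer

result = pgd(
    f=f,
    g=g,
    initial_value=torch.zeros_like(image),
    stepsize=1.0,
    reg_parameter=0.01,
)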

