Commit e23054b

fix merge
1 parent: fcf4442

3 files changed: 6 additions, 22 deletions

src/mrpro/algorithms/optimizers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 from mrpro.algorithms.optimizers.lbfgs import lbfgs
 from mrpro.algorithms.optimizers.pgd import pgd
 
-__all__ = ["OptimizerStatus", "adam", "cg", "lbfgs", "pdg"]
+__all__ = ["OptimizerStatus", "adam", "cg", "lbfgs", "pdg", "pgd"]

src/mrpro/algorithms/optimizers/pgd.py

Lines changed: 3 additions & 19 deletions
@@ -4,28 +4,12 @@
 
 import torch
 
-from mrpro.operators.Functional import Functional
-
-
-def grad_and_value(function, x, create_graph=False):
-    def inner(x):
-        if create_graph and x.requires_grad:
-            # shallow clone
-            xg = x.view_as(x)
-
-        xg = x.detach().requires_grad_(True)
-        (y,) = function(xg)
-
-        yg = y if isinstance(y, torch.Tensor) else y[0]
-        grad = torch.autograd.grad(yg, xg)[0]
-        return grad, y
-
-    return inner(x)
+from mrpro.operators.Functional import Functional, ProximableFunctional
 
 
 def pgd(
     f: Functional,
-    g: Functional,
+    g: ProximableFunctional,  # TODO: would it work with ProximableFunctionalSeparableSum?
     initial_value: torch.Tensor,
     stepsize: float = 1.0,
     reg_parameter: float = 0.01,
@@ -97,7 +81,7 @@ def pgd(
     for _ in range(max_iterations):
        while stepsize > 1e-30:
            # calculate the proximal gradient step
-            gradient, f_y = grad_and_value(f, y)
+            gradient, f_y = torch.func.grad_and_value(f, y)
            (x,) = g.prox(y - stepsize * gradient, reg_parameter * stepsize)
 
            if not backtracking:
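
The removed helper computed the value of f and its gradient by hand via torch.autograd; the replacement leans on torch.func.grad_and_value instead. A minimal standalone sketch of one proximal gradient step in that style follows. It is not taken from pgd.py: mrpro functionals return a one-element tuple (see the removed helper's `(y,) = function(xg)`), so the plain scalar-valued f and the soft-thresholding prox below are simplified stand-ins for the Functional and for g.prox.

    import torch

    def f(x: torch.Tensor) -> torch.Tensor:
        # stand-in for a scalar-valued data-fidelity functional
        return (x ** 2).sum()

    def soft_threshold(x: torch.Tensor, threshold: float) -> torch.Tensor:
        # proximal operator of threshold * ||.||_1, standing in for g.prox
        return torch.sign(x) * torch.clamp(x.abs() - threshold, min=0.0)

    y = torch.randn(8)
    stepsize, reg_parameter = 0.5, 0.01

    # torch.func.grad_and_value(f) returns a transformed function; calling it
    # on y yields the gradient and the function value in a single pass.
    gradient, f_y = torch.func.grad_and_value(f)(y)

    # gradient step on f, then proximal step on g
    x = soft_threshold(y - stepsize * gradient, reg_parameter * stepsize)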

tests/algorithms/test_pgd.py

Lines changed: 2 additions & 2 deletions
@@ -1,10 +1,10 @@
 """Tests for the proximal gradient descent."""
+
 import torch
 from mrpro.algorithms.optimizers import pgd
 from mrpro.data.SpatialDimension import SpatialDimension
 from mrpro.operators import FastFourierOp
-from mrpro.operators.functionals import L1Norm
-from mrpro.operators.functionals import L2NormSquared
+from mrpro.operators.functionals import L1Norm, L2NormSquared
 from mrpro.phantoms import EllipsePhantom
 
 