Skip to content

Commit 4fdf1ae

Browse files
committed
optim refactoring
* adding solver hook to get trainer, batch and losses before optim * moving self._instance in Optimizer
1 parent 5966d03 commit 4fdf1ae

File tree

5 files changed

+71
-33
lines changed

5 files changed

+71
-33
lines changed

pina/optim/optimizer_interface.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module for the PINA Optimizer."""
22

33
from abc import ABCMeta, abstractmethod
4+
from ..utils import check_consistency
45

56

67
class Optimizer(metaclass=ABCMeta):
@@ -9,15 +10,56 @@ class Optimizer(metaclass=ABCMeta):
910
should inherit form this class and implement the required methods.
1011
"""
1112

13+
def __init__(self):
14+
"""
15+
Initialization of the :class:`Optimizer` class.
16+
"""
17+
self._extra_optim_args = {}
18+
self._optimizer_instance = None
19+
1220
@property
13-
@abstractmethod
1421
def instance(self):
1522
"""
16-
Abstract property to retrieve the optimizer instance.
23+
Get the optimizer instance.
24+
25+
:return: The optimizer instance.
26+
:rtype: torch.optim.Optimizer
27+
"""
28+
return self._optimizer_instance
29+
30+
@instance.setter
31+
def instance(self, value):
32+
"""
33+
Set the optimizer instance.
34+
35+
:param Any value: The optimizer instance.
1736
"""
37+
self._optimizer_instance = value
1838

1939
@abstractmethod
2040
def hook(self):
2141
"""
2242
Abstract method to define the hook logic for the optimizer.
2343
"""
44+
45+
def get_optim_extra_args(self, trainer, batch, losses):
46+
"""
47+
Retrieve and set extra optimizer arguments from the optimizer instance.
48+
49+
This method calls the ``get_optim_extra_args`` method of the underlying
50+
optimizer instance (if it exists) and stores the resulting dictionary in
51+
:attr:`extra_optim_args` of the optimizer instance.
52+
53+
:param trainer: The training manager controlling the optimization loop.
54+
:type trainer: :class:`~pina.trainer.Trainer`
55+
:param dict batch: The current batch of data used for training.
56+
:param dict losses: Dictionary containing the computed loss values.
57+
:raises ValueError: If ``get_optim_extra_args`` does not
58+
return a dictionary.
59+
"""
60+
if hasattr(self.instance, "get_optim_extra_args"):
61+
extra_args = self.instance.get_optim_extra_args(
62+
trainer=trainer, batch=batch, losses=losses
63+
)
64+
check_consistency(extra_args, dict)
65+
self.instance.extra_optim_args = extra_args

pina/optim/scheduler_interface.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,31 @@ class Scheduler(metaclass=ABCMeta):
99
inherit form this class and implement the required methods.
1010
"""
1111

12+
def __init__(self):
13+
"""
14+
Initialization of the :class:`Scheduler` class.
15+
"""
16+
self._schedule_instance = None
17+
self._last_lr = None
18+
1219
@property
13-
@abstractmethod
1420
def instance(self):
1521
"""
1622
Get the scheduler instance.
1723
"""
24+
return self._schedule_instance
25+
26+
@instance.setter
27+
def instance(self, value):
28+
"""
29+
Set the scheduler instance.
30+
31+
:param Any value: The scheduler instance.
32+
"""
33+
self._schedule_instance = value
1834

1935
@abstractmethod
2036
def hook(self):
2137
"""
2238
Abstract method to define the hook logic for the scheduler.
23-
"""
39+
"""

pina/optim/torch_optimizer.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,13 @@ def __init__(self, optimizer_class, **kwargs):
2222
`here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
2323
"""
2424
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
25-
2625
self.optimizer_class = optimizer_class
2726
self.kwargs = kwargs
28-
self._optimizer_instance = None
2927

3028
def hook(self, parameters):
3129
"""
3230
Initialize the optimizer instance with the given parameters.
3331
3432
:param dict parameters: The parameters of the model to be optimized.
3533
"""
36-
self._optimizer_instance = self.optimizer_class(
37-
parameters, **self.kwargs
38-
)
39-
40-
@property
41-
def instance(self):
42-
"""
43-
Get the optimizer instance.
44-
45-
:return: The optimizer instance.
46-
:rtype: torch.optim.Optimizer
47-
"""
48-
return self._optimizer_instance
34+
self.instance = self.optimizer_class(parameters, **self.kwargs)

pina/optim/torch_scheduler.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def __init__(self, scheduler_class, **kwargs):
3131

3232
self.scheduler_class = scheduler_class
3333
self.kwargs = kwargs
34-
self._scheduler_instance = None
3534

3635
def hook(self, optimizer):
3736
"""
@@ -40,16 +39,6 @@ def hook(self, optimizer):
4039
:param dict parameters: The parameters of the optimizer.
4140
"""
4241
check_consistency(optimizer, Optimizer)
43-
self._scheduler_instance = self.scheduler_class(
42+
self.instance = self.scheduler_class(
4443
optimizer.instance, **self.kwargs
45-
)
46-
47-
@property
48-
def instance(self):
49-
"""
50-
Get the scheduler instance.
51-
52-
:return: The scheduelr instance.
53-
:rtype: torch.optim.LRScheduler
54-
"""
55-
return self._scheduler_instance
44+
)

pina/solver/solver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,12 @@ def _optimization_cycle(self, batch, **kwargs):
243243
:rtype: dict
244244
"""
245245
# compute losses
246-
losses = self.optimization_cycle(batch)
246+
losses = self.optimization_cycle(batch, **kwargs)
247+
# attach trainer, batch, losses to optimizer extra args
248+
for optim in self._pina_optimizers:
249+
optim.get_optim_extra_args(
250+
trainer=self.trainer, batch=batch, losses=losses
251+
)
247252
# clamp unknown parameters in InverseProblem (if needed)
248253
self._clamp_params()
249254
# store log

0 commit comments

Comments
 (0)