
Commit

Merge branch 'develop' into rich_wires
ziofil authored Nov 14, 2024
2 parents 12b1118 + f3c226d commit 4780a68
Showing 1 changed file with 45 additions and 22 deletions.
67 changes: 45 additions & 22 deletions mrmustard/training/optimizer.py
@@ -18,7 +18,7 @@

 from itertools import chain, groupby
 from typing import List, Callable, Sequence, Union, Mapping, Dict
-from mrmustard import math
+from mrmustard import math, settings
 from mrmustard.math.parameters import Constant, Variable
 from mrmustard.training.callbacks import Callback
 from mrmustard.training.progress_bar import ProgressBar
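
The newly imported settings object gates the progress bar in the hunk below. A minimal usage sketch (the diff only confirms that settings.PROGRESSBAR is read in _minimize; that the flag is writable at runtime is an assumption):

    # Sketch: opting out of the progress bar from user code.
    # Assumption: settings.PROGRESSBAR is a writable runtime flag.
    from mrmustard import settings

    settings.PROGRESSBAR = False  # _minimize will then skip ProgressBar entirely
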
@@ -98,35 +98,58 @@ def minimize(
     def _minimize(self, cost_fn, by_optimizing, max_steps, callbacks):
         # finding out which parameters are trainable from the ops
         trainable_params = self._get_trainable_params(by_optimizing)
+        if settings.PROGRESSBAR:
+            bar = ProgressBar(max_steps)
+            with bar:
+                self._optimization_loop(cost_fn, trainable_params, max_steps, callbacks, bar)
+        else:
+            self._optimization_loop(cost_fn, trainable_params, max_steps, callbacks)
+
+    def _optimization_loop(
+        self, cost_fn, trainable_params, max_steps, callbacks, progress_bar=None
+    ):
+        """Internal method that performs the main optimization loop.
+
+        Args:
+            cost_fn (Callable): The cost function to minimize
+            trainable_params (dict): Dictionary of trainable parameters
+            max_steps (int): Maximum number of optimization steps
+            callbacks (dict): Dictionary of callback functions to execute during optimization
+            progress_bar (ProgressBar, optional): Progress bar instance for displaying optimization progress.
+                If None, no progress will be displayed. Defaults to None.
+
+        Note:
+            This method maintains internal state in self.opt_history and self.callback_history,
+            tracking the optimization progress and callback results respectively.
+        """
         cost_fn_modified = False
         orig_cost_fn = cost_fn
 
-        bar = ProgressBar(max_steps)
-        with bar:
-            while not self.should_stop(max_steps):
-                cost, grads = self.compute_loss_and_gradients(cost_fn, trainable_params.values())
+        while not self.should_stop(max_steps):
+            cost, grads = self.compute_loss_and_gradients(cost_fn, trainable_params.values())
 
-                trainables = {tag: (x, dx) for (tag, x), dx in zip(trainable_params.items(), grads)}
+            trainables = {tag: (x, dx) for (tag, x), dx in zip(trainable_params.items(), grads)}
 
-                if cost_fn_modified:
-                    self.callback_history["orig_cost"].append(orig_cost_fn())
+            if cost_fn_modified:
+                self.callback_history["orig_cost"].append(orig_cost_fn())
 
-                new_cost_fn, new_grads = self._run_callbacks(
-                    callbacks=callbacks,
-                    cost_fn=cost_fn,
-                    cost=cost,
-                    trainables=trainables,
-                )
+            new_cost_fn, new_grads = self._run_callbacks(
+                callbacks=callbacks,
+                cost_fn=cost_fn,
+                cost=cost,
+                trainables=trainables,
+            )
 
-                self.apply_gradients(trainable_params.values(), new_grads or grads)
-                self.opt_history.append(cost)
-                bar.step(math.asnumpy(cost))
+            self.apply_gradients(trainable_params.values(), new_grads or grads)
+            self.opt_history.append(cost)
+            if progress_bar is not None:
+                progress_bar.step(math.asnumpy(cost))
 
-                if callable(new_cost_fn):
-                    cost_fn = new_cost_fn
-                    if not cost_fn_modified:
-                        cost_fn_modified = True
-                        self.callback_history["orig_cost"] = self.opt_history.copy()
+            if callable(new_cost_fn):
+                cost_fn = new_cost_fn
+                if not cost_fn_modified:
+                    cost_fn_modified = True
+                    self.callback_history["orig_cost"] = self.opt_history.copy()

     def apply_gradients(self, trainable_params, grads):
         """Apply gradients to variables.
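
Note the shape of this refactor: rather than duplicating the training loop under each branch, _minimize builds the ProgressBar only when settings.PROGRESSBAR is set and hands it to _optimization_loop as an optional argument, so one copy of the loop serves both paths. A self-contained sketch of the same pattern with stand-in names (nothing below is mrmustard API):

    # Stand-in sketch of the gating pattern used in _minimize above.
    # ToyBar and run_loop are hypothetical; only the control flow mirrors the diff.
    class ToyBar:
        def __enter__(self):
            return self

        def __exit__(self, *exc):
            return False

        def step(self, cost):
            print(f"cost: {cost:.4f}")

    def run_loop(steps, progress_bar=None):
        cost = 1.0
        for _ in range(steps):
            cost *= 0.9  # stand-in for one optimization step
            if progress_bar is not None:  # same guard as in the new loop body
                progress_bar.step(cost)
        return cost

    def run(steps, show_progress):
        if show_progress:  # mirrors `if settings.PROGRESSBAR:`
            bar = ToyBar()
            with bar:
                return run_loop(steps, progress_bar=bar)
        return run_loop(steps)

    run(3, show_progress=True)  # prints three cost lines; run(3, False) prints none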
