diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 800b7ae2f..c47f63539 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2435,7 +2435,12 @@ def save( _DATASET_KEY: DatasetArgs(self._dataset_iterator), } args = ocp.args.Composite(**args) - saved = self._manager.save(step, args=args, force=force) + saved = self._manager.save( + step, + args=args, + force=force, + metrics=self._options.metric_name_to_monitor, + ) # Record JAX monitoring events. end_time = time.time()