Description
Hello,
I am using TensorFlow 2.15, and the stop_evaluating attribute does not work.
In a model, calling tf.print(self.stop_evaluating) raises an AttributeError. The same happens anywhere in the test_step() method or in an evaluation callback. Also, setting self.model.stop_evaluating = True in a callback, for instance in on_test_batch_end(), does not stop the evaluation.
I tried to understand why.
As far as I can tell, stop_training and stop_evaluating are implemented in a similar way in the TensorFlowTrainer class.
In fit():
    self.stop_training = False
    self.make_train_function()
    callbacks.on_train_begin()
    training_logs = None
    logs = None
    initial_epoch = self._initial_epoch or initial_epoch
    for epoch in range(initial_epoch, epochs):
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)
        with epoch_iterator.catch_stop_iteration():
            for step, iterator in epoch_iterator.enumerate_epoch():
                callbacks.on_train_batch_begin(step)
                logs = self.train_function(iterator)
                logs = self._pythonify_logs(logs)
                callbacks.on_train_batch_end(step, logs)
                if self.stop_training:
                    break
In evaluate():
    self.make_test_function()
    self.stop_evaluating = False
    callbacks.on_test_begin()
    logs = None
    self.reset_metrics()
    with epoch_iterator.catch_stop_iteration():
        for step, iterator in epoch_iterator.enumerate_epoch():
            callbacks.on_test_batch_begin(step)
            logs = self.test_function(iterator)
            logs = self._pythonify_logs(logs)
            callbacks.on_test_batch_end(step, logs)
            if self.stop_evaluating:
                break
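To make the behavior I expected concrete, here is a minimal pure-Python sketch of the pattern both loops appear to share: the loop checks a flag after each batch callback, and the callback can set that flag to break out early. ToyModel and StopAfter are hypothetical names I made up for illustration; they are not Keras classes, and this is not the Keras implementation.

```python
class ToyModel:
    """Hypothetical stand-in for a trainer; not a Keras class."""

    def __init__(self):
        self.stop_evaluating = False
        self.batches_seen = 0

    def evaluate(self, num_batches, callback):
        # Same shape as the quoted loop: reset the flag, run a batch,
        # call the callback, then check the flag.
        self.stop_evaluating = False
        self.batches_seen = 0
        for step in range(num_batches):
            self.batches_seen += 1
            callback.on_test_batch_end(self, step)
            if self.stop_evaluating:
                break
        return self.batches_seen


class StopAfter:
    """Hypothetical callback that requests a stop once `limit` batches have run."""

    def __init__(self, limit):
        self.limit = limit

    def on_test_batch_end(self, model, step):
        if step + 1 >= self.limit:
            model.stop_evaluating = True


model = ToyModel()
seen = model.evaluate(num_batches=10, callback=StopAfter(limit=3))
print(seen)
```

In this toy version the loop stops after the third batch, which is what I expected from setting self.model.stop_evaluating = True in on_test_batch_end().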
The only difference I see is that the compile() method, defined in the base Trainer class and inherited by TensorFlowTrainer, only initializes self.stop_training and not self.stop_evaluating:
    self.jit_compile = jit_compile
    self.run_eagerly = run_eagerly
    self.stop_training = False
    self.compiled = True
    self._loss_tracker = metrics_module.Mean(name="loss")
    self.steps_per_execution = steps_per_execution
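To show why I think this difference could matter, here is a small pure-Python sketch. It only assumes that an attribute assigned inside evaluate() does not exist until evaluate() has run, while one assigned in compile() does. ToyTrainer is a hypothetical class for illustration, not the Keras Trainer, and I have not confirmed that this is what actually happens inside Keras.

```python
class ToyTrainer:
    """Hypothetical class that mirrors only where each flag is first assigned."""

    def compile(self):
        # Initialized here, as in the quoted compile() excerpt.
        self.stop_training = False

    def evaluate(self):
        # Only created once evaluate() has been called.
        self.stop_evaluating = False


trainer = ToyTrainer()
trainer.compile()

has_training_flag = hasattr(trainer, "stop_training")
has_evaluating_flag = hasattr(trainer, "stop_evaluating")

# Reading the flag before evaluate() has run raises AttributeError.
try:
    trainer.stop_evaluating
    raised = False
except AttributeError:
    raised = True

# A defensive read with a default avoids the error in user code.
safe_flag = getattr(trainer, "stop_evaluating", False)

print(has_training_flag, has_evaluating_flag, raised, safe_flag)
```

If the same thing applies to the real class, initializing self.stop_evaluating = False next to self.stop_training in compile() might avoid the AttributeError, but that is only a guess on my part.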
I am not an expert, so I don't know whether this is the right lead.
Thanks in advance!