
Model attribute stop_evaluating does not work #19812

Closed
@Inv4lidn4m3

Description


Hello,

I am using TensorFlow 2.15, and the attribute stop_evaluating does not work.

In a model, calling tf.print(self.stop_evaluating) raises an AttributeError. The same happens anywhere in the test_step() method or in an evaluation callback. Also, setting self.model.stop_evaluating = True in a callback (in on_test_batch_end(), for instance) does not stop the evaluation.

I tried to understand why.

As far as I can tell, stop_training and stop_evaluating are implemented the same way in the TensorFlowTrainer class.

In fit():

```python
self.stop_training = False
self.make_train_function()
callbacks.on_train_begin()
training_logs = None
logs = None
initial_epoch = self._initial_epoch or initial_epoch
for epoch in range(initial_epoch, epochs):
    self.reset_metrics()
    callbacks.on_epoch_begin(epoch)
    with epoch_iterator.catch_stop_iteration():
        for step, iterator in epoch_iterator.enumerate_epoch():
            callbacks.on_train_batch_begin(step)
            logs = self.train_function(iterator)
            logs = self._pythonify_logs(logs)
            callbacks.on_train_batch_end(step, logs)
            if self.stop_training:
                break
```
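To make the flag mechanism concrete, here is a plain-Python toy that mirrors the shape of the fit() loop above with no TensorFlow dependency. ToyFitter and StopTrainAfter are hypothetical names, not Keras classes. The epoch-level check after the inner loop is an assumption, since the excerpt above is truncated before the end of the epoch.

```python
class StopTrainAfter:
    """Hypothetical callback: request a stop once `n` training batches have run."""

    def __init__(self, n):
        self.n = n
        self.count = 0
        self.model = None  # attached by the trainer, like a Keras callback's .model

    def on_train_batch_end(self, batch, logs=None):
        self.count += 1
        if self.count >= self.n:
            self.model.stop_training = True  # flip the flag from inside a callback


class ToyFitter:
    """Hypothetical trainer reproducing only the fit() loop shape quoted above."""

    def train_function(self, batch):
        return {"loss": float(batch)}  # stand-in for a real training step

    def fit(self, batches, epochs, callbacks=()):
        for cb in callbacks:
            cb.model = self
        self.stop_training = False  # reset at the start of fit(), as in the excerpt
        ran = []
        for epoch in range(epochs):
            for step, batch in enumerate(batches):
                logs = self.train_function(batch)
                ran.append((epoch, step))
                for cb in callbacks:
                    cb.on_train_batch_end(step, logs)
                if self.stop_training:  # polled once per batch, after the callbacks
                    break
            # Assumption: an epoch-level check like this follows the truncated excerpt.
            if self.stop_training:
                break
        return ran


history = ToyFitter().fit(range(4), epochs=3, callbacks=[StopTrainAfter(6)])
print(history)  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1)]
```

Because the flag is only checked after on_train_batch_end(), the batch during which the callback sets it still completes; the loop exits before the next batch starts.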

In evaluate():

```python
self.make_test_function()
self.stop_evaluating = False
callbacks.on_test_begin()
logs = None
self.reset_metrics()
with epoch_iterator.catch_stop_iteration():
    for step, iterator in epoch_iterator.enumerate_epoch():
        callbacks.on_test_batch_begin(step)
        logs = self.test_function(iterator)
        logs = self._pythonify_logs(logs)
        callbacks.on_test_batch_end(step, logs)
        if self.stop_evaluating:
            break
```
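The evaluate() loop uses the same pattern. The following hypothetical plain-Python toy (ToyEvaluator and StopEvalAfter are illustrative names, not Keras classes) reproduces the scenario from the report: a callback sets model.stop_evaluating = True in on_test_batch_end(), and the loop breaks right after that batch.

```python
class StopEvalAfter:
    """Hypothetical callback: request an early stop after `n` test batches."""

    def __init__(self, n):
        self.n = n
        self.model = None  # attached by the trainer, like a Keras callback's .model

    def on_test_batch_end(self, batch, logs=None):
        if batch + 1 >= self.n:
            # The pattern described in the issue: set the flag from a callback.
            self.model.stop_evaluating = True


class ToyEvaluator:
    """Hypothetical trainer reproducing only the evaluate() loop shape quoted above."""

    def test_function(self, batch):
        return {"loss": float(batch)}  # stand-in for a real test step

    def evaluate(self, batches, callbacks=()):
        for cb in callbacks:
            cb.model = self
        self.stop_evaluating = False  # reset at the start of every evaluate() call
        seen = []
        for step, batch in enumerate(batches):
            logs = self.test_function(batch)
            seen.append(step)
            for cb in callbacks:
                cb.on_test_batch_end(step, logs)
            if self.stop_evaluating:  # polled once per batch, after the callbacks
                break
        return seen


evaluator = ToyEvaluator()
print(evaluator.evaluate(range(10), callbacks=[StopEvalAfter(3)]))  # [0, 1, 2]
```

In this toy, evaluation stops after the third batch, which is the behavior the report expected. Because the flag is reset at the start of evaluate(), as in the quoted code, a value set before evaluate() is called would be discarded.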

The only difference I see is that the compile() method of the base Trainer class, which TensorFlowTrainer inherits, initializes self.stop_training but not self.stop_evaluating:

```python
self.jit_compile = jit_compile
self.run_eagerly = run_eagerly
self.stop_training = False
self.compiled = True
self._loss_tracker = metrics_module.Mean(name="loss")
self.steps_per_execution = steps_per_execution
```
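The hypothesis above can be illustrated in plain Python. An attribute assigned only inside evaluate() does not exist until evaluate() has run, so code that reads it earlier raises AttributeError, whereas an attribute initialized in compile() is always present afterwards. ToyTrainer, PatchedTrainer, and read_flag are hypothetical names. This sketch shows the Python mechanism behind the hypothesis, not the confirmed cause of the bug.

```python
class ToyTrainer:
    """Hypothetical base class mimicking the quoted compile(): only stop_training is set."""

    def compile(self):
        self.stop_training = False
        self.compiled = True

    def evaluate(self):
        self.stop_evaluating = False  # the attribute only comes into existence here


class PatchedTrainer(ToyTrainer):
    """Hypothetical variant that also initializes stop_evaluating in compile()."""

    def compile(self):
        super().compile()
        self.stop_evaluating = False


def read_flag(model):
    """Return the flag, or the string 'AttributeError' if it does not exist yet."""
    try:
        return model.stop_evaluating
    except AttributeError:
        return "AttributeError"


plain = ToyTrainer()
plain.compile()
print(read_flag(plain))  # AttributeError

patched = PatchedTrainer()
patched.compile()
print(read_flag(patched))  # False

plain.evaluate()
print(read_flag(plain))  # False
```

Under this model, the unpatched class only raises when the flag is read before evaluate() assigns it; once evaluate() has run, the attribute exists.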

I'm not an expert, so I don't know whether this is the right lead.

Thanks in advance!
