
[BUG Report]: EarlyStopping Callback Stops Training Prematurely Regardless of Performance Improvement #1239

Open
@Patan77

Description


Issue Description:
I am experiencing an issue with the EarlyStopping callback in TensorFlow.NET where the training process stops exactly after the number of epochs specified in the patience parameter, even if the monitored metric (in this case, accuracy) is still improving. This behavior is inconsistent with the expected functionality of the EarlyStopping callback, which should only stop training after the specified patience epochs if there is no improvement.

Steps to Reproduce:

Set up a model and compile it with accuracy as a metric.
Initialize the EarlyStopping callback with patience set to 10 and monitor set to "accuracy".
Train the model using model.fit() with a significant number of epochs (e.g., 1000).
Observe that training stops exactly after 10 epochs regardless of performance improvements.

Expected Behavior:
The training should continue past the patience number of epochs as long as the monitored metric (accuracy) is still improving. The training should only stop if there is no improvement in the metric for the duration specified by patience.

Actual Behavior:
The training stops exactly after the number of epochs specified by patience, even though the accuracy is still improving.
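One possible explanation, assuming TensorFlow.NET mirrors the Keras semantics for baseline: in Keras, when a baseline is given, the "best" value starts at the baseline rather than negative infinity, so every epoch whose accuracy stays below baseline + min_delta counts as "no improvement". With baseline: 0.999f and accuracies that never exceed 0.957, the wait counter would then reach patience at exactly epoch 10, matching the logs above. The following stand-alone sketch (not the TensorFlow.NET source; SimulateEarlyStopping is a made-up helper, mode "max" assumed) reproduces that bookkeeping:

```csharp
// Hypothetical sketch of Keras-style EarlyStopping bookkeeping (mode: "max").
// Returns the 1-based epoch at which training would stop, or -1 if it never stops.
static int SimulateEarlyStopping(float[] accuracyPerEpoch, int patience,
                                 float minDelta, float? baseline)
{
    // In Keras, `best` starts at the baseline when one is given;
    // otherwise it starts at negative infinity.
    float best = baseline ?? float.NegativeInfinity;
    int wait = 0;

    for (int epoch = 0; epoch < accuracyPerEpoch.Length; epoch++)
    {
        if (accuracyPerEpoch[epoch] > best + minDelta)
        {
            best = accuracyPerEpoch[epoch]; // improvement: reset the counter
            wait = 0;
        }
        else if (++wait >= patience)
        {
            return epoch + 1; // patience exhausted: stop training
        }
    }
    return -1; // training ran to completion
}
```

Fed the ten logged accuracies with baseline: 0.999f this returns 10 (stop at epoch 10, as observed); with baseline left null it returns -1, i.e. training continues while accuracy improves, which is the behavior the report expects.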

Code Snippet:

 int epochCount = 1000;

 var callbackParams = new CallbackParams
 {
     Model = model,
     Epochs = epochCount,
 };

 var earlyStoppingCallback = new EarlyStopping(
     parameters: callbackParams,
     monitor: "accuracy",
     min_delta: 0.0001f,
     patience: 10,
     mode: "max",
     baseline: 0.999f,
     restore_best_weights: true
 );

 var customCallback = new CustomCallback(); // logs loss/accuracy each epoch (class below)

 model.fit(inputs, labels, epochs: epochCount, batch_size: 100, callbacks: new List<ICallback> { customCallback, earlyStoppingCallback });

 public class CustomCallback : ICallback
 {
     public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
     {
         // Check if the logs contain the loss and accuracy keys
         if (epoch_logs.ContainsKey("loss"))
         {
             System.Diagnostics.Debug.WriteLine($"Epoch {epoch + 1}, Loss: {epoch_logs["loss"]}");
         }

         if (epoch_logs.ContainsKey("accuracy")) // or "acc", depending on the version and setup
         {
             System.Diagnostics.Debug.WriteLine($"Epoch {epoch + 1}, Accuracy: {epoch_logs["accuracy"]}");
         }
     }
 }

Logs / Output:

Epoch 1, Loss: 0.67718434
Epoch 1, Accuracy: 0.625
Epoch 2, Loss: 0.6478515
Epoch 2, Accuracy: 0.684
Epoch 3, Loss: 0.6122737
Epoch 3, Accuracy: 0.711
Epoch 4, Loss: 0.5662995
Epoch 4, Accuracy: 0.758
Epoch 5, Loss: 0.5094736
Epoch 5, Accuracy: 0.828
Epoch 6, Loss: 0.44897592
Epoch 6, Accuracy: 0.873
Epoch 7, Loss: 0.3877059
Epoch 7, Accuracy: 0.908
Epoch 8, Loss: 0.33612937
Epoch 8, Accuracy: 0.935
Epoch 9, Loss: 0.29862627
Epoch 9, Accuracy: 0.94
Epoch 10, Loss: 0.2639172
Epoch 10, Accuracy: 0.957
Model trained complete.

Reproduction Steps

No response

Known Workarounds

Not a full workaround, but setting baseline = 0.0f appears to make the patience counter behave as expected. It still does not help if you also want to enforce an "accuracy" threshold via baseline.
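If the goal of baseline was really "stop once accuracy reaches a target", that can be split off into a separate callback so EarlyStopping can run with baseline = 0.0f (or none). The sketch below is hypothetical: AccuracyThresholdCallback is a made-up class, and it assumes the model exposes a Keras-style Stop_training flag, whose exact name may differ in your TensorFlow.NET version.

```csharp
// Hypothetical sketch: enforce an accuracy threshold independently of
// EarlyStopping's baseline. Assumes a Keras-style stop flag on the model;
// the exact property name may differ across TensorFlow.NET versions.
public class AccuracyThresholdCallback : ICallback
{
    private readonly Model _model;     // the model being trained
    private readonly float _threshold; // target accuracy, e.g. 0.999f

    public AccuracyThresholdCallback(Model model, float threshold)
    {
        _model = model;
        _threshold = threshold;
    }

    public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
    {
        if (epoch_logs.TryGetValue("accuracy", out var acc) && acc >= _threshold)
        {
            // Request a stop once the target accuracy is reached.
            _model.Stop_training = true; // assumption: Keras-style flag
        }
    }
}
```

With this, EarlyStopping handles only the "no improvement for patience epochs" case, and the threshold callback handles the "good enough, stop now" case.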

Configuration and Other Information

SciSharp.TensorFlow.Redist version 2.16.0
TensorFlow.NET version 0.150.0
.NET 6.0
Operating System: Windows 10
